| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import math |
| from typing import TYPE_CHECKING |
|
|
| import torch |
|
|
| from ..configuration_utils import register_to_config |
| from ..utils import is_kornia_available |
| from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg |
|
|
|
|
| if TYPE_CHECKING: |
| from ..modular_pipelines.modular_pipeline import BlockState |
|
|
|
|
# kornia is an optional dependency that provides the Laplacian-pyramid utilities used by FDG.
_CAN_USE_KORNIA = is_kornia_available()

if not _CAN_USE_KORNIA:
    # Placeholders keep this module importable without kornia; `FrequencyDecoupledGuidance.__init__` checks
    # `_CAN_USE_KORNIA` and refuses to construct the guider in that case.
    upsample_and_blur_func = None
    build_laplacian_pyramid_func = None
else:
    from kornia.geometry import pyrup as upsample_and_blur_func
    from kornia.geometry.transform import build_laplacian_pyramid as build_laplacian_pyramid_func
|
|
|
|
def project(v0: torch.Tensor, v1: torch.Tensor, upcast_to_double: bool = True) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Split `v0` into the component parallel to `v1` and the component orthogonal to it. Implementation from the paper
    (Algorithm 2).

    All dimensions except the first are flattened into one vector, so each slice along dimension 0 is projected
    independently.

    Args:
        v0: Tensor to decompose.
        v1: Tensor giving the projection direction. It is L2-normalized internally, so only its direction matters.
        upcast_to_double: If `True`, do the math in float64 and cast both results back to `v0.dtype`.

    Returns:
        `(v0_parallel, v0_orthogonal)`. Their sum equals `v0` up to floating-point error.
    """
    reduce_dims = list(range(1, v0.ndim))

    # Remember the caller's dtype only when we are about to change it.
    output_dtype = v0.dtype if upcast_to_double else None
    if upcast_to_double:
        v0 = v0.double()
        v1 = v1.double()

    direction = torch.nn.functional.normalize(v1, dim=reduce_dims)
    coefficient = torch.sum(v0 * direction, dim=reduce_dims, keepdim=True)
    parallel = coefficient * direction
    orthogonal = v0 - parallel

    if output_dtype is not None:
        parallel = parallel.to(output_dtype)
        orthogonal = orthogonal.to(output_dtype)
    return parallel, orthogonal
|
|
|
|
def build_image_from_pyramid(pyramid: list[torch.Tensor]) -> torch.Tensor:
    """
    Invert a Laplacian pyramid back to data space. Implementation from the paper (Algorithm 2).

    `pyramid` is ordered from the finest band (index 0) to the coarsest one (last entry). Starting from the coarsest
    entry, the running reconstruction is upsampled with `upsample_and_blur_func` and the next finer band is added,
    until band 0 has been added.

    Args:
        pyramid: Pyramid bands, finest first. Must be non-empty (an empty list raises `IndexError`).

    Returns:
        The reconstructed tensor.
    """
    reconstruction = pyramid[-1]
    for band in reversed(pyramid[:-1]):
        reconstruction = upsample_and_blur_func(reconstruction) + band
    return reconstruction
|
|
|
|
class FrequencyDecoupledGuidance(BaseGuidance):
    """
    Frequency-Decoupled Guidance (FDG): https://huggingface.co/papers/2506.19713

    FDG is a technique similar to (and based on) classifier-free guidance (CFG) which is used to improve generation
    quality and condition-following in diffusion models. Like CFG, during training we jointly train the model on both
    conditional and unconditional data, and use a combination of the two during inference. (If you want more details on
    how CFG works, you can check out the CFG guider.)

    FDG differs from CFG in that the normal CFG prediction is instead decoupled into low- and high-frequency components
    using a frequency transform (such as a Laplacian pyramid). The CFG update is then performed in frequency space
    separately for the low- and high-frequency components with different guidance scales. Finally, the inverse
    frequency transform is used to map the CFG frequency predictions back to data space (e.g. pixel space for images)
    to form the final FDG prediction.

    For images, the FDG authors found that using low guidance scales for the low-frequency components retains sample
    diversity and realistic color composition, while using high guidance scales for high-frequency components enhances
    sample quality (such as better visual details). Therefore, they recommend using low guidance scales (low w_low) for
    the low-frequency components and high guidance scales (high w_high) for the high-frequency components. As an
    example, they suggest w_low = 5.0 and w_high = 10.0 for Stable Diffusion XL (see Table 8 in the paper).

    As with CFG, Diffusers implements the scaling and shifting on the unconditional prediction based on the [Imagen
    paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original CFG paper proposed in
    theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]

    The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
    paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.

    Args:
        guidance_scales (`list[float]`, defaults to `[10.0, 5.0]`):
            The scale parameter for frequency-decoupled guidance for each frequency component, listed from highest
            frequency level to lowest. Higher values result in stronger conditioning on the text prompt, while lower
            values allow for more freedom in generation. Higher values may lead to saturation and deterioration of
            image quality. The FDG authors recommend using higher guidance scales for higher frequency components and
            lower guidance scales for lower frequency components (so `guidance_scales` should typically be sorted in
            descending order). The number of entries sets the number of pyramid levels.
        guidance_rescale (`float` or `list[float]`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891). If a list is supplied, it should be the same length as
            `guidance_scales`.
        parallel_weights (`float` or `list[float]`, *optional*):
            Optional weights for the parallel component of each frequency component of the projected CFG shift. If not
            set, the weights will default to `1.0` for all components, which corresponds to using the normal CFG shift
            (that is, equal weights for the parallel and orthogonal components). If set, a value in `[0, 1]` is
            recommended. If a list is supplied, it should be the same length as `guidance_scales`.
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
        start (`float` or `list[float]`, defaults to `0.0`):
            The fraction of the total number of denoising steps after which guidance starts. If a list is supplied, it
            should be the same length as `guidance_scales`.
        stop (`float` or `list[float]`, defaults to `1.0`):
            The fraction of the total number of denoising steps after which guidance stops. If a list is supplied, it
            should be the same length as `guidance_scales`.
        guidance_rescale_space (`str`, defaults to `"data"`):
            Whether to perform guidance rescaling in `"data"` space (after the full FDG update in data space) or in
            `"freq"` space (right after the CFG update, for each freq level). Note that frequency space rescaling is
            speculative and may not produce expected results. If `"data"` is set, the first `guidance_rescale` value
            will be used; otherwise, per-frequency-level guidance rescale values will be used if available.
        upcast_to_double (`bool`, defaults to `True`):
            Whether to upcast certain operations, such as the projection operation when using `parallel_weights`, to
            float64 when performing guidance. This may result in better performance at the cost of increased runtime.
        enabled (`bool`, defaults to `True`):
            Whether the guider is active. When `False`, `forward` returns `pred_cond` unchanged and only one condition
            is requested.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    @register_to_config
    def __init__(
        self,
        guidance_scales: list[float] | tuple[float] = [10.0, 5.0],
        guidance_rescale: float | list[float] | tuple[float] = 0.0,
        parallel_weights: float | list[float] | tuple[float] | None = None,
        use_original_formulation: bool = False,
        start: float | list[float] | tuple[float] = 0.0,
        stop: float | list[float] | tuple[float] = 1.0,
        guidance_rescale_space: str = "data",
        upcast_to_double: bool = True,
        enabled: bool = True,
    ):
        if not _CAN_USE_KORNIA:
            raise ImportError(
                "The `FrequencyDecoupledGuidance` guider cannot be instantiated because the `kornia` library on which "
                "it depends is not available in the current environment. You can install `kornia` with `pip install "
                "kornia`."
            )

        # The base class tracks one [start, stop) window. Give it the widest window spanned by all levels, so that FDG
        # as a whole is considered active whenever any level may be. Per-level windows are checked in `forward`.
        # Scalars are tested against (int, float) so that e.g. `start=0` or `stop=1` are accepted.
        min_start = start if isinstance(start, (int, float)) else min(start)
        max_stop = stop if isinstance(stop, (int, float)) else max(stop)
        super().__init__(min_start, max_stop, enabled)

        # Copy the list: the default argument is a shared mutable list, and `register_to_config` also keeps a
        # reference to it. In-place edits of `self.guidance_scales` must not leak into other instances.
        self.guidance_scales = list(guidance_scales)
        self.levels = len(self.guidance_scales)

        self.guidance_rescale = self._expand_per_level(guidance_rescale, "guidance_rescale")

        if guidance_rescale_space not in ["data", "freq"]:
            raise ValueError(
                f"Guidance rescale space is {guidance_rescale_space} but must be one of `data` or `freq`."
            )
        self.guidance_rescale_space = guidance_rescale_space

        # `None` means a weight of 1.0 per level, which keeps the parallel component, i.e. the plain CFG shift.
        if parallel_weights is None:
            parallel_weights = 1.0
        self.parallel_weights = self._expand_per_level(parallel_weights, "parallel_weights")

        self.use_original_formulation = use_original_formulation
        self.upcast_to_double = upcast_to_double

        self.guidance_start = self._expand_per_level(start, "start")
        self.guidance_stop = self._expand_per_level(stop, "stop")

    def _expand_per_level(self, value, name: str) -> list:
        """
        Return one value per frequency level.

        A scalar (int or float) is repeated `self.levels` times. A sequence is copied into a new list after checking
        that it has exactly `self.levels` entries.

        Raises:
            ValueError: If `value` is a sequence whose length differs from `len(guidance_scales)`.
        """
        if isinstance(value, (int, float)):
            return [value] * self.levels
        if len(value) != self.levels:
            raise ValueError(
                f"`{name}` has length {len(value)} but should have the same length as `guidance_scales` "
                f"({self.levels})"
            )
        return list(value)

    def prepare_inputs(self, data: dict[str, tuple[torch.Tensor, torch.Tensor]]) -> list["BlockState"]:
        """
        Build one batch per active condition from `data`.

        Returns a batch for `"pred_cond"` only when one condition is active, or batches for `"pred_cond"` and
        `"pred_uncond"` when two are. Tuple index 0 or 1 selects the matching entry of each `data` value.
        """
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def prepare_inputs_from_block_state(
        self, data: "BlockState", input_fields: dict[str, str | tuple[str, str]]
    ) -> list["BlockState"]:
        """Same as `prepare_inputs`, but reads the fields named in `input_fields` from a `BlockState`."""
        tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
        data_batches = []
        for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
            data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
            data_batches.append(data_batch)
        return data_batches

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor | None = None) -> GuiderOutput:
        """
        Combine the conditional and unconditional predictions with frequency-decoupled guidance.

        If FDG is inactive at the current step, `pred_cond` is returned as the prediction. Otherwise, both predictions
        are decomposed into `self.levels`-level Laplacian pyramids, and a CFG-style update is applied per level using
        that level's guidance scale and parallel weight. The result is mapped back to data space, with optional
        guidance rescaling in data or frequency space.
        """
        pred = None

        if not self._is_fdg_enabled():
            pred = pred_cond
        else:
            # Decompose both predictions into Laplacian pyramids. Level 0 is the highest-frequency band.
            pred_cond_pyramid = build_laplacian_pyramid_func(pred_cond, self.levels)
            pred_uncond_pyramid = build_laplacian_pyramid_func(pred_uncond, self.levels)

            pred_guided_pyramid = []
            parameters = zip(self.guidance_scales, self.parallel_weights, self.guidance_rescale)
            for level, (guidance_scale, parallel_weight, guidance_rescale) in enumerate(parameters):
                # Bind this level's conditional band before branching: it is needed whether or not the level is
                # guided. Binding it only in the guided branch would make a skipped level reuse the previous level's
                # band, or raise UnboundLocalError at level 0.
                pred_cond_freq = pred_cond_pyramid[level]

                if not self._is_fdg_enabled_for_level(level):
                    # No guidance at this level (outside its start/stop window or a no-op scale): keep the cond band.
                    pred_guided_pyramid.append(pred_cond_freq)
                    continue

                pred_uncond_freq = pred_uncond_pyramid[level]

                # CFG shift for this frequency band.
                shift = pred_cond_freq - pred_uncond_freq

                # Optionally reweight the part of the shift parallel to the conditional band. The projection is
                # skipped for a weight of 1.0, since the shift would then be unchanged.
                if not math.isclose(parallel_weight, 1.0):
                    shift_parallel, shift_orthogonal = project(shift, pred_cond_freq, self.upcast_to_double)
                    shift = parallel_weight * shift_parallel + shift_orthogonal

                # Original formulation shifts from cond; the diffusers formulation shifts from uncond.
                pred_freq = pred_cond_freq if self.use_original_formulation else pred_uncond_freq
                pred_freq = pred_freq + guidance_scale * shift

                if self.guidance_rescale_space == "freq" and guidance_rescale > 0.0:
                    pred_freq = rescale_noise_cfg(pred_freq, pred_cond_freq, guidance_rescale)

                pred_guided_pyramid.append(pred_freq)

            # Inverse frequency transform back to data space.
            pred = build_image_from_pyramid(pred_guided_pyramid)

            # In data space, the first `guidance_rescale` entry acts as the global rescale factor for all levels.
            if self.guidance_rescale_space == "data" and self.guidance_rescale[0] > 0.0:
                pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale[0])

        return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)

    @property
    def is_conditional(self) -> bool:
        # NOTE(review): assumes `_count_prepared` (maintained by BaseGuidance) is 1 while the first prepared batch,
        # the conditional one, is being processed. Confirm against BaseGuidance.
        return self._count_prepared == 1

    @property
    def num_conditions(self) -> int:
        """2 (cond + uncond) when FDG is active at the current step, otherwise 1 (cond only)."""
        num_conditions = 1
        if self._is_fdg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_fdg_enabled(self) -> bool:
        """
        Whether FDG runs at the current step.

        Requires the guider to be enabled, the step to lie inside the global [start, stop) window, and at least one
        level's guidance scale to differ from the no-op value.
        """
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self._start * self._num_inference_steps)
            skip_stop_step = int(self._stop * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        # No-op scale: 0.0 for the original formulation, 1.0 for the diffusers formulation. With the plain CFG shift,
        # either one yields the conditional prediction.
        if self.use_original_formulation:
            is_close = all(math.isclose(guidance_scale, 0.0) for guidance_scale in self.guidance_scales)
        else:
            is_close = all(math.isclose(guidance_scale, 1.0) for guidance_scale in self.guidance_scales)

        return is_within_range and not is_close

    def _is_fdg_enabled_for_level(self, level: int) -> bool:
        """Per-level version of `_is_fdg_enabled`, using that level's start/stop window and guidance scale."""
        if not self._enabled:
            return False

        is_within_range = True
        if self._num_inference_steps is not None:
            skip_start_step = int(self.guidance_start[level] * self._num_inference_steps)
            skip_stop_step = int(self.guidance_stop[level] * self._num_inference_steps)
            is_within_range = skip_start_step <= self._step < skip_stop_step

        if self.use_original_formulation:
            is_close = math.isclose(self.guidance_scales[level], 0.0)
        else:
            is_close = math.isclose(self.guidance_scales[level], 1.0)

        return is_within_range and not is_close
|
|