import re |
|
|
from dataclasses import dataclass |
|
|
from typing import Any, Callable, List, Optional, Tuple |
|
|
|
|
|
import torch |
|
|
|
|
|
from ..models.attention_processor import Attention, MochiAttention |
|
|
from ..models.modeling_outputs import Transformer2DModelOutput |
|
|
from ..utils import logging |
|
|
from .hooks import HookRegistry, ModelHook |
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
_FASTER_CACHE_DENOISER_HOOK = "faster_cache_denoiser" |
|
|
_FASTER_CACHE_BLOCK_HOOK = "faster_cache_block" |
|
|
_ATTENTION_CLASSES = (Attention, MochiAttention) |
|
|
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = ( |
|
|
"^blocks.*attn", |
|
|
"^transformer_blocks.*attn", |
|
|
"^single_transformer_blocks.*attn", |
|
|
) |
|
|
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",) |
|
|
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS |
|
|
_UNCOND_COND_INPUT_KWARGS_IDENTIFIERS = ( |
|
|
"hidden_states", |
|
|
"encoder_hidden_states", |
|
|
"timestep", |
|
|
"attention_mask", |
|
|
"encoder_attention_mask", |
|
|
) |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class FasterCacheConfig: |
|
|
r""" |
|
|
Configuration for [FasterCache](https://huggingface.co/papers/2410.19355). |
|
|
|
|
|
Attributes: |
|
|
spatial_attention_block_skip_range (`int`, defaults to `2`): |
|
|
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will |
|
|
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention |
|
|
states again. |
|
|
temporal_attention_block_skip_range (`int`, *optional*, defaults to `None`): |
|
|
Calculate the attention states every `N` iterations. If this is set to `N`, the attention computation will |
|
|
be skipped `N - 1` times (i.e., cached attention states will be re-used) before computing the new attention |
|
|
states again. |
|
|
        spatial_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(-1, 681)`):
|
|
The timestep range within which the spatial attention computation can be skipped without a significant loss |
|
|
in quality. This is to be determined by the user based on the underlying model. The first value in the |
|
|
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for |
|
|
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at |
|
|
timestep 0). For the default values, this would mean that the spatial attention computation skipping will |
|
|
be applicable only after denoising timestep 681 is reached, and continue until the end of the denoising |
|
|
process. |
|
|
        temporal_attention_timestep_skip_range (`Tuple[int, int]`, defaults to `(-1, 681)`):
|
|
The timestep range within which the temporal attention computation can be skipped without a significant |
|
|
loss in quality. This is to be determined by the user based on the underlying model. The first value in the |
|
|
tuple is the lower bound and the second value is the upper bound. Typically, diffusion timesteps for |
|
|
denoising are in the reversed range of 0 to 1000 (i.e. denoising starts at timestep 1000 and ends at |
|
|
timestep 0). |
|
|
low_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(99, 901)`): |
|
|
The timestep range within which the low frequency weight scaling update is applied. The first value in the |
|
|
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback |
|
|
function for the update is called only within this range. |
|
|
high_frequency_weight_update_timestep_range (`Tuple[int, int]`, defaults to `(-1, 301)`): |
|
|
The timestep range within which the high frequency weight scaling update is applied. The first value in the |
|
|
tuple is the lower bound and the second value is the upper bound of the timestep range. The callback |
|
|
function for the update is called only within this range. |
|
|
alpha_low_frequency (`float`, defaults to `1.1`): |
|
|
The weight to scale the low frequency updates by. This is used to approximate the unconditional branch from |
|
|
the conditional branch outputs. |
|
|
alpha_high_frequency (`float`, defaults to `1.1`): |
|
|
The weight to scale the high frequency updates by. This is used to approximate the unconditional branch |
|
|
from the conditional branch outputs. |
|
|
unconditional_batch_skip_range (`int`, defaults to `5`): |
|
|
Process the unconditional branch every `N` iterations. If this is set to `N`, the unconditional branch |
|
|
computation will be skipped `N - 1` times (i.e., cached unconditional branch states will be re-used) before |
|
|
computing the new unconditional branch states again. |
|
|
        unconditional_batch_timestep_skip_range (`Tuple[int, int]`, defaults to `(-1, 641)`):
|
|
The timestep range within which the unconditional branch computation can be skipped without a significant |
|
|
loss in quality. This is to be determined by the user based on the underlying model. The first value in the |
|
|
tuple is the lower bound and the second value is the upper bound. |
|
|
        spatial_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^blocks.*attn", "^transformer_blocks.*attn", "^single_transformer_blocks.*attn")`):
|
|
The identifiers to match the spatial attention blocks in the model. If the name of the block contains any |
|
|
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names, |
|
|
partial layer names, or regex patterns. Matching will always be done using a regex match. |
|
|
        temporal_attention_block_identifiers (`Tuple[str, ...]`, defaults to `("^temporal_transformer_blocks.*attn",)`):
|
|
The identifiers to match the temporal attention blocks in the model. If the name of the block contains any |
|
|
of these identifiers, FasterCache will be applied to that block. This can either be the full layer names, |
|
|
partial layer names, or regex patterns. Matching will always be done using a regex match. |
|
|
attention_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`): |
|
|
The callback function to determine the weight to scale the attention outputs by. This function should take |
|
|
the attention module as input and return a float value. This is used to approximate the unconditional |
|
|
branch from the conditional branch outputs. If not provided, the default weight is 0.5 for all timesteps. |
|
|
Typically, as described in the paper, this weight should gradually increase from 0 to 1 as the inference |
|
|
progresses. Users are encouraged to experiment and provide custom weight schedules that take into account |
|
|
the number of inference steps and underlying model behaviour as denoising progresses. |
|
|
low_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`): |
|
|
The callback function to determine the weight to scale the low frequency updates by. If not provided, the |
|
|
default weight is 1.1 for timesteps within the range specified (as described in the paper). |
|
|
high_frequency_weight_callback (`Callable[[torch.nn.Module], float]`, defaults to `None`): |
|
|
The callback function to determine the weight to scale the high frequency updates by. If not provided, the |
|
|
default weight is 1.1 for timesteps within the range specified (as described in the paper). |
|
|
tensor_format (`str`, defaults to `"BCFHW"`): |
|
|
The format of the input tensors. This should be one of `"BCFHW"`, `"BFCHW"`, or `"BCHW"`. The format is |
|
|
used to split individual latent frames in order for low and high frequency components to be computed. |
|
|
is_guidance_distilled (`bool`, defaults to `False`): |
|
|
Whether the model is guidance distilled or not. If the model is guidance distilled, FasterCache will not be |
|
|
applied at the denoiser-level to skip the unconditional branch computation (as there is none). |
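        current_timestep_callback (`Callable[[], int]`, defaults to `None`):
            The callback function to retrieve the current timestep during the denoising process. This must be
            provided by the user (there is no built-in default), as the hooks invoke it on every forward pass to
            decide whether computations can be skipped at the current timestep.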
|
|
_unconditional_conditional_input_kwargs_identifiers (`List[str]`, defaults to `("hidden_states", "encoder_hidden_states", "timestep", "attention_mask", "encoder_attention_mask")`): |
|
|
The identifiers to match the input kwargs that contain the batchwise-concatenated unconditional and |
|
|
conditional inputs. If the name of the input kwargs contains any of these identifiers, FasterCache will |
|
|
split the inputs into unconditional and conditional branches. This must be a list of exact input kwargs |
|
|
names that contain the batchwise-concatenated unconditional and conditional inputs. |
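
    Example:

    ```python
    # A minimal sketch. `pipe` is assumed to be an already-loaded video diffusion pipeline whose
    # `current_timestep` property exposes the timestep of the ongoing denoising step; adapt the callback to
    # however your pipeline tracks the current timestep.
    config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(-1, 681),
        attention_weight_callback=lambda _: 0.5,
        current_timestep_callback=lambda: pipe.current_timestep,
    )
    ```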
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
spatial_attention_block_skip_range: int = 2 |
|
|
temporal_attention_block_skip_range: Optional[int] = None |
|
|
|
|
|
spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) |
|
|
temporal_attention_timestep_skip_range: Tuple[int, int] = (-1, 681) |
|
|
|
|
|
|
|
|
low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901) |
|
|
high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301) |
|
|
|
|
|
|
|
|
alpha_low_frequency: float = 1.1 |
|
|
alpha_high_frequency: float = 1.1 |
|
|
|
|
|
|
|
|
unconditional_batch_skip_range: int = 5 |
|
|
unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641) |
|
|
|
|
|
spatial_attention_block_identifiers: Tuple[str, ...] = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS |
|
|
temporal_attention_block_identifiers: Tuple[str, ...] = _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS |
|
|
|
|
|
    attention_weight_callback: Optional[Callable[[torch.nn.Module], float]] = None
    low_frequency_weight_callback: Optional[Callable[[torch.nn.Module], float]] = None
    high_frequency_weight_callback: Optional[Callable[[torch.nn.Module], float]] = None
|
|
|
|
|
tensor_format: str = "BCFHW" |
|
|
is_guidance_distilled: bool = False |
|
|
|
|
|
    current_timestep_callback: Optional[Callable[[], int]] = None
|
|
|
|
|
_unconditional_conditional_input_kwargs_identifiers: List[str] = _UNCOND_COND_INPUT_KWARGS_IDENTIFIERS |
|
|
|
|
|
def __repr__(self) -> str: |
|
|
return ( |
|
|
f"FasterCacheConfig(\n" |
|
|
f" spatial_attention_block_skip_range={self.spatial_attention_block_skip_range},\n" |
|
|
f" temporal_attention_block_skip_range={self.temporal_attention_block_skip_range},\n" |
|
|
f" spatial_attention_timestep_skip_range={self.spatial_attention_timestep_skip_range},\n" |
|
|
f" temporal_attention_timestep_skip_range={self.temporal_attention_timestep_skip_range},\n" |
|
|
f" low_frequency_weight_update_timestep_range={self.low_frequency_weight_update_timestep_range},\n" |
|
|
f" high_frequency_weight_update_timestep_range={self.high_frequency_weight_update_timestep_range},\n" |
|
|
f" alpha_low_frequency={self.alpha_low_frequency},\n" |
|
|
f" alpha_high_frequency={self.alpha_high_frequency},\n" |
|
|
f" unconditional_batch_skip_range={self.unconditional_batch_skip_range},\n" |
|
|
f" unconditional_batch_timestep_skip_range={self.unconditional_batch_timestep_skip_range},\n" |
|
|
f" spatial_attention_block_identifiers={self.spatial_attention_block_identifiers},\n" |
|
|
f" temporal_attention_block_identifiers={self.temporal_attention_block_identifiers},\n" |
|
|
f" tensor_format={self.tensor_format},\n" |
|
|
f")" |
|
|
) |
|
|
|
|
|
|
|
|
class FasterCacheDenoiserState: |
|
|
r""" |
|
|
State for [FasterCache](https://huggingface.co/papers/2410.19355) top-level denoiser module. |
|
|
""" |
|
|
|
|
|
def __init__(self) -> None: |
|
|
self.iteration: int = 0 |
|
|
        self.low_frequency_delta: Optional[torch.Tensor] = None
        self.high_frequency_delta: Optional[torch.Tensor] = None
|
|
|
|
|
def reset(self): |
|
|
self.iteration = 0 |
|
|
self.low_frequency_delta = None |
|
|
self.high_frequency_delta = None |
|
|
|
|
|
|
|
|
class FasterCacheBlockState: |
|
|
r""" |
|
|
State for [FasterCache](https://huggingface.co/papers/2410.19355). Every underlying block that FasterCache is |
|
|
applied to will have an instance of this state. |
|
|
""" |
|
|
|
|
|
def __init__(self) -> None: |
|
|
self.iteration: int = 0 |
|
|
        self.batch_size: Optional[int] = None
        self.cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
|
|
|
|
|
def reset(self): |
|
|
self.iteration = 0 |
|
|
self.batch_size = None |
|
|
self.cache = None |
|
|
|
|
|
|
|
|
class FasterCacheDenoiserHook(ModelHook): |
|
|
_is_stateful = True |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
unconditional_batch_skip_range: int, |
|
|
unconditional_batch_timestep_skip_range: Tuple[int, int], |
|
|
tensor_format: str, |
|
|
is_guidance_distilled: bool, |
|
|
uncond_cond_input_kwargs_identifiers: List[str], |
|
|
current_timestep_callback: Callable[[], int], |
|
|
low_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor], |
|
|
high_frequency_weight_callback: Callable[[torch.nn.Module], torch.Tensor], |
|
|
) -> None: |
|
|
super().__init__() |
|
|
|
|
|
self.unconditional_batch_skip_range = unconditional_batch_skip_range |
|
|
self.unconditional_batch_timestep_skip_range = unconditional_batch_timestep_skip_range |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.uncond_cond_input_kwargs_identifiers = uncond_cond_input_kwargs_identifiers |
|
|
self.tensor_format = tensor_format |
|
|
self.is_guidance_distilled = is_guidance_distilled |
|
|
|
|
|
self.current_timestep_callback = current_timestep_callback |
|
|
self.low_frequency_weight_callback = low_frequency_weight_callback |
|
|
self.high_frequency_weight_callback = high_frequency_weight_callback |
|
|
|
|
|
def initialize_hook(self, module): |
|
|
self.state = FasterCacheDenoiserState() |
|
|
return module |
|
|
|
|
|
@staticmethod |
|
|
    def _get_cond_input(input: torch.Tensor) -> torch.Tensor:
        # Inputs are assumed to be batched as [unconditional, conditional] along the batch dimension, which is
        # the convention used by Diffusers pipelines for classifier-free guidance. Only the conditional half is
        # returned here.
        _, cond = input.chunk(2, dim=0)
        return cond
|
|
|
|
|
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: |
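        # The unconditional branch is skipped only when all of the following hold:
        #   1. At least one full iteration of the denoiser has been completed (so deltas are cached).
        #   2. The current timestep falls within the user-specified skip range.
        #   3. The current iteration is not a multiple of `unconditional_batch_skip_range`, so the unconditional
        #      branch is still computed periodically to limit quality loss.
        #   4. The model is not guidance-distilled (distilled models have no unconditional branch).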
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_within_timestep_range = ( |
|
|
self.unconditional_batch_timestep_skip_range[0] |
|
|
< self.current_timestep_callback() |
|
|
< self.unconditional_batch_timestep_skip_range[1] |
|
|
) |
|
|
should_skip_uncond = ( |
|
|
self.state.iteration > 0 |
|
|
and is_within_timestep_range |
|
|
and self.state.iteration % self.unconditional_batch_skip_range != 0 |
|
|
and not self.is_guidance_distilled |
|
|
) |
|
|
|
|
|
if should_skip_uncond: |
|
|
is_any_kwarg_uncond = any(k in self.uncond_cond_input_kwargs_identifiers for k in kwargs.keys()) |
|
|
if is_any_kwarg_uncond: |
|
|
logger.debug("FasterCache - Skipping unconditional branch computation") |
|
|
args = tuple([self._get_cond_input(arg) if torch.is_tensor(arg) else arg for arg in args]) |
|
|
kwargs = { |
|
|
k: v if k not in self.uncond_cond_input_kwargs_identifiers else self._get_cond_input(v) |
|
|
for k, v in kwargs.items() |
|
|
} |
|
|
|
|
|
output = self.fn_ref.original_forward(*args, **kwargs) |
|
|
|
|
|
if self.is_guidance_distilled: |
|
|
self.state.iteration += 1 |
|
|
return output |
|
|
|
|
|
if torch.is_tensor(output): |
|
|
hidden_states = output |
|
|
elif isinstance(output, (tuple, Transformer2DModelOutput)): |
|
|
hidden_states = output[0] |
|
|
|
|
|
batch_size = hidden_states.size(0) |
|
|
|
|
|
if should_skip_uncond: |
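            # Approximate the unconditional branch: scale the cached low/high-frequency deltas, add them to the
            # frequency components of the conditional output, and invert the FFT to recover spatial states.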
|
|
self.state.low_frequency_delta = self.state.low_frequency_delta * self.low_frequency_weight_callback( |
|
|
module |
|
|
) |
|
|
self.state.high_frequency_delta = self.state.high_frequency_delta * self.high_frequency_weight_callback( |
|
|
module |
|
|
) |
|
|
|
|
|
if self.tensor_format == "BCFHW": |
|
|
hidden_states = hidden_states.permute(0, 2, 1, 3, 4) |
|
|
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": |
|
|
hidden_states = hidden_states.flatten(0, 1) |
|
|
|
|
|
low_freq_cond, high_freq_cond = _split_low_high_freq(hidden_states.float()) |
|
|
|
|
|
|
|
|
low_freq_uncond = self.state.low_frequency_delta + low_freq_cond |
|
|
high_freq_uncond = self.state.high_frequency_delta + high_freq_cond |
|
|
uncond_freq = low_freq_uncond + high_freq_uncond |
|
|
|
|
|
uncond_states = torch.fft.ifftshift(uncond_freq) |
|
|
uncond_states = torch.fft.ifft2(uncond_states).real |
|
|
|
|
|
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": |
|
|
uncond_states = uncond_states.unflatten(0, (batch_size, -1)) |
|
|
hidden_states = hidden_states.unflatten(0, (batch_size, -1)) |
|
|
if self.tensor_format == "BCFHW": |
|
|
uncond_states = uncond_states.permute(0, 2, 1, 3, 4) |
|
|
hidden_states = hidden_states.permute(0, 2, 1, 3, 4) |
|
|
|
|
|
|
|
|
uncond_states = uncond_states.to(hidden_states.dtype) |
|
|
hidden_states = torch.cat([uncond_states, hidden_states], dim=0) |
|
|
else: |
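            # The unconditional branch was computed this iteration: cache the low/high-frequency deltas between
            # the unconditional and conditional outputs for reuse in later skipped iterations.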
|
|
uncond_states, cond_states = hidden_states.chunk(2, dim=0) |
|
|
if self.tensor_format == "BCFHW": |
|
|
uncond_states = uncond_states.permute(0, 2, 1, 3, 4) |
|
|
cond_states = cond_states.permute(0, 2, 1, 3, 4) |
|
|
if self.tensor_format == "BCFHW" or self.tensor_format == "BFCHW": |
|
|
uncond_states = uncond_states.flatten(0, 1) |
|
|
cond_states = cond_states.flatten(0, 1) |
|
|
|
|
|
low_freq_uncond, high_freq_uncond = _split_low_high_freq(uncond_states.float()) |
|
|
low_freq_cond, high_freq_cond = _split_low_high_freq(cond_states.float()) |
|
|
self.state.low_frequency_delta = low_freq_uncond - low_freq_cond |
|
|
self.state.high_frequency_delta = high_freq_uncond - high_freq_cond |
|
|
|
|
|
self.state.iteration += 1 |
|
|
if torch.is_tensor(output): |
|
|
output = hidden_states |
|
|
elif isinstance(output, tuple): |
|
|
output = (hidden_states, *output[1:]) |
|
|
else: |
|
|
output.sample = hidden_states |
|
|
|
|
|
return output |
|
|
|
|
|
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: |
|
|
self.state.reset() |
|
|
return module |
|
|
|
|
|
|
|
|
class FasterCacheBlockHook(ModelHook): |
|
|
_is_stateful = True |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
block_skip_range: int, |
|
|
timestep_skip_range: Tuple[int, int], |
|
|
is_guidance_distilled: bool, |
|
|
weight_callback: Callable[[torch.nn.Module], float], |
|
|
current_timestep_callback: Callable[[], int], |
|
|
) -> None: |
|
|
super().__init__() |
|
|
|
|
|
self.block_skip_range = block_skip_range |
|
|
self.timestep_skip_range = timestep_skip_range |
|
|
self.is_guidance_distilled = is_guidance_distilled |
|
|
|
|
|
self.weight_callback = weight_callback |
|
|
self.current_timestep_callback = current_timestep_callback |
|
|
|
|
|
def initialize_hook(self, module): |
|
|
self.state = FasterCacheBlockState() |
|
|
return module |
|
|
|
|
|
def _compute_approximated_attention_output( |
|
|
self, t_2_output: torch.Tensor, t_output: torch.Tensor, weight: float, batch_size: int |
|
|
) -> torch.Tensor: |
|
|
        if t_2_output.size(0) != batch_size:
            # The cached output contains the batchwise-concatenated unconditional and conditional states; only
            # the conditional half is needed for the approximation.
            assert t_2_output.size(0) == 2 * batch_size
            t_2_output = t_2_output[batch_size:]
        if t_output.size(0) != batch_size:
            # Same as above, for the more recent cached output.
            assert t_output.size(0) == 2 * batch_size
            t_output = t_output[batch_size:]
        # Linear extrapolation from the two most recent attention outputs, scaled by the user-provided weight.
        return t_output + (t_output - t_2_output) * weight
|
|
|
|
|
def new_forward(self, module: torch.nn.Module, *args, **kwargs) -> Any: |
|
|
batch_size = [ |
|
|
*[arg.size(0) for arg in args if torch.is_tensor(arg)], |
|
|
*[v.size(0) for v in kwargs.values() if torch.is_tensor(v)], |
|
|
][0] |
|
|
if self.state.batch_size is None: |
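            # Set on the first forward pass: records the full (unconditional + conditional) batch size seen by
            # the denoiser before any branch skipping occurs.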
|
|
|
|
|
self.state.batch_size = batch_size |
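
        # Attention may be skipped only within the configured timestep range, and only on iterations where the
        # cached states are fresh enough (i.e. the iteration is not a multiple of `block_skip_range`). When the
        # model is not guidance-distilled, skipping is additionally only safe if the denoiser is currently
        # running just the conditional branch: the cache stores conditional outputs only, so returning cached
        # values for a full unconditional+conditional batch would produce mismatched shapes.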
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
is_within_timestep_range = ( |
|
|
self.timestep_skip_range[0] < self.current_timestep_callback() < self.timestep_skip_range[1] |
|
|
) |
|
|
if not is_within_timestep_range: |
|
|
should_skip_attention = False |
|
|
else: |
|
|
should_compute_attention = self.state.iteration > 0 and self.state.iteration % self.block_skip_range == 0 |
|
|
should_skip_attention = not should_compute_attention |
|
|
if should_skip_attention: |
|
|
should_skip_attention = self.is_guidance_distilled or self.state.batch_size != batch_size |
|
|
|
|
|
if should_skip_attention: |
|
|
logger.debug("FasterCache - Skipping attention and using approximation") |
|
|
if torch.is_tensor(self.state.cache[-1]): |
|
|
t_2_output, t_output = self.state.cache |
|
|
weight = self.weight_callback(module) |
|
|
output = self._compute_approximated_attention_output(t_2_output, t_output, weight, batch_size) |
|
|
else: |
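                # The block returns multiple tensors (e.g. hidden_states and encoder_hidden_states), so the
                # cache holds tuples; approximate each element of the output tuple separately.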
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
output = () |
|
|
for t_2_output, t_output in zip(*self.state.cache): |
|
|
result = self._compute_approximated_attention_output( |
|
|
t_2_output, t_output, self.weight_callback(module), batch_size |
|
|
) |
|
|
output += (result,) |
|
|
else: |
|
|
logger.debug("FasterCache - Computing attention") |
|
|
output = self.fn_ref.original_forward(*args, **kwargs) |
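
        # Cache the conditional-branch outputs for use in future skipped iterations. When the output still
        # contains both branches (classifier-free guidance), only the conditional half is stored.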
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if torch.is_tensor(output): |
|
|
cache_output = output |
|
|
if not self.is_guidance_distilled and cache_output.size(0) == self.state.batch_size: |
|
|
|
|
|
|
|
|
cache_output = cache_output.chunk(2, dim=0)[1] |
|
|
else: |
|
|
|
|
|
cache_output = () |
|
|
for out in output: |
|
|
if not self.is_guidance_distilled and out.size(0) == self.state.batch_size: |
|
|
out = out.chunk(2, dim=0)[1] |
|
|
cache_output += (out,) |
|
|
|
|
|
if self.state.cache is None: |
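            # First iteration: seed both cache slots with the same output so the extrapolation is initially a
            # no-op.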
|
|
self.state.cache = [cache_output, cache_output] |
|
|
else: |
|
|
self.state.cache = [self.state.cache[-1], cache_output] |
|
|
|
|
|
self.state.iteration += 1 |
|
|
return output |
|
|
|
|
|
def reset_state(self, module: torch.nn.Module) -> torch.nn.Module: |
|
|
self.state.reset() |
|
|
return module |
|
|
|
|
|
|
|
|
def apply_faster_cache(module: torch.nn.Module, config: FasterCacheConfig) -> None: |
|
|
r""" |
|
|
Applies [FasterCache](https://huggingface.co/papers/2410.19355) to a given pipeline. |
|
|
|
|
|
Args: |
|
|
        module (`torch.nn.Module`):
            The denoiser module (e.g. the transformer of a diffusion pipeline) to apply FasterCache to.
        config (`FasterCacheConfig`):
            The configuration to use for FasterCache.
|
|
|
|
|
Example: |
|
|
```python |
|
|
>>> import torch |
|
|
>>> from diffusers import CogVideoXPipeline, FasterCacheConfig, apply_faster_cache |
|
|
|
|
|
>>> pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16) |
|
|
>>> pipe.to("cuda") |
|
|
|
|
|
>>> config = FasterCacheConfig( |
|
|
... spatial_attention_block_skip_range=2, |
|
|
... spatial_attention_timestep_skip_range=(-1, 681), |
|
|
... low_frequency_weight_update_timestep_range=(99, 641), |
|
|
... high_frequency_weight_update_timestep_range=(-1, 301), |
|
|
... spatial_attention_block_identifiers=["transformer_blocks"], |
|
|
        ...     attention_weight_callback=lambda _: 0.3,
        ...     current_timestep_callback=lambda: pipe.current_timestep,
|
|
... tensor_format="BFCHW", |
|
|
... ) |
|
|
>>> apply_faster_cache(pipe.transformer, config) |
|
|
``` |
|
|
""" |
|
|
|
|
|
logger.warning( |
|
|
"FasterCache is a purely experimental feature and may not work as expected. Not all models support FasterCache. " |
|
|
"The API is subject to change in future releases, with no guarantee of backward compatibility. Please report any issues at " |
|
|
"https://github.com/huggingface/diffusers/issues." |
|
|
) |
|
|
|
|
|
if config.attention_weight_callback is None: |
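        # As described in the paper, the attention weight should ideally grow from 0 to 1 over the course of
        # inference. Since a good schedule is model-dependent, a constant 0.5 is used as a conservative default
        # when no callback is provided.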
|
|
|
|
|
|
|
|
|
|
|
|
|
|
logger.warning( |
|
|
"No `attention_weight_callback` provided when enabling FasterCache. Defaulting to using a weight of 0.5 for all timesteps." |
|
|
) |
|
|
config.attention_weight_callback = lambda _: 0.5 |
|
|
|
|
|
if config.low_frequency_weight_callback is None: |
|
|
logger.debug( |
|
|
"Low frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." |
|
|
) |
|
|
|
|
|
def low_frequency_weight_callback(module: torch.nn.Module) -> float: |
|
|
is_within_range = ( |
|
|
config.low_frequency_weight_update_timestep_range[0] |
|
|
< config.current_timestep_callback() |
|
|
< config.low_frequency_weight_update_timestep_range[1] |
|
|
) |
|
|
return config.alpha_low_frequency if is_within_range else 1.0 |
|
|
|
|
|
config.low_frequency_weight_callback = low_frequency_weight_callback |
|
|
|
|
|
if config.high_frequency_weight_callback is None: |
|
|
logger.debug( |
|
|
"High frequency weight callback not provided when enabling FasterCache. Defaulting to behaviour described in the paper." |
|
|
) |
|
|
|
|
|
def high_frequency_weight_callback(module: torch.nn.Module) -> float: |
|
|
is_within_range = ( |
|
|
config.high_frequency_weight_update_timestep_range[0] |
|
|
< config.current_timestep_callback() |
|
|
< config.high_frequency_weight_update_timestep_range[1] |
|
|
) |
|
|
return config.alpha_high_frequency if is_within_range else 1.0 |
|
|
|
|
|
config.high_frequency_weight_callback = high_frequency_weight_callback |
|
|
|
|
|
supported_tensor_formats = ["BCFHW", "BFCHW", "BCHW"] |
|
|
if config.tensor_format not in supported_tensor_formats: |
|
|
raise ValueError(f"`tensor_format` must be one of {supported_tensor_formats}, but got {config.tensor_format}.") |
|
|
|
|
|
_apply_faster_cache_on_denoiser(module, config) |
|
|
|
|
|
for name, submodule in module.named_modules(): |
|
|
if not isinstance(submodule, _ATTENTION_CLASSES): |
|
|
continue |
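        # This outer check only filters to attention layers inside known transformer-block patterns; per-layer
        # matching against the configured spatial/temporal identifiers happens inside
        # `_apply_faster_cache_on_attention_class`.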
|
|
if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS): |
|
|
_apply_faster_cache_on_attention_class(name, submodule, config) |
|
|
|
|
|
|
|
|
def _apply_faster_cache_on_denoiser(module: torch.nn.Module, config: FasterCacheConfig) -> None: |
|
|
hook = FasterCacheDenoiserHook( |
|
|
config.unconditional_batch_skip_range, |
|
|
config.unconditional_batch_timestep_skip_range, |
|
|
config.tensor_format, |
|
|
config.is_guidance_distilled, |
|
|
config._unconditional_conditional_input_kwargs_identifiers, |
|
|
config.current_timestep_callback, |
|
|
config.low_frequency_weight_callback, |
|
|
config.high_frequency_weight_callback, |
|
|
) |
|
|
registry = HookRegistry.check_if_exists_or_initialize(module) |
|
|
registry.register_hook(hook, _FASTER_CACHE_DENOISER_HOOK) |
|
|
|
|
|
|
|
|
def _apply_faster_cache_on_attention_class(name: str, module: Attention, config: FasterCacheConfig) -> None: |
|
|
is_spatial_self_attention = ( |
|
|
any(re.search(identifier, name) is not None for identifier in config.spatial_attention_block_identifiers) |
|
|
and config.spatial_attention_block_skip_range is not None |
|
|
and not getattr(module, "is_cross_attention", False) |
|
|
) |
|
|
    is_temporal_self_attention = (
        any(re.search(identifier, name) is not None for identifier in config.temporal_attention_block_identifiers)
        and config.temporal_attention_block_skip_range is not None
        and not getattr(module, "is_cross_attention", False)
    )
|
|
|
|
|
block_skip_range, timestep_skip_range, block_type = None, None, None |
|
|
if is_spatial_self_attention: |
|
|
block_skip_range = config.spatial_attention_block_skip_range |
|
|
timestep_skip_range = config.spatial_attention_timestep_skip_range |
|
|
block_type = "spatial" |
|
|
elif is_temporal_self_attention: |
|
|
block_skip_range = config.temporal_attention_block_skip_range |
|
|
timestep_skip_range = config.temporal_attention_timestep_skip_range |
|
|
block_type = "temporal" |
|
|
|
|
|
if block_skip_range is None or timestep_skip_range is None: |
|
|
logger.debug( |
|
|
f'Unable to apply FasterCache to the selected layer: "{name}" because it does ' |
|
|
f"not match any of the required criteria for spatial or temporal attention layers. Note, " |
|
|
f"however, that this layer may still be valid for applying PAB. Please specify the correct " |
|
|
f"block identifiers in the configuration or use the specialized `apply_faster_cache_on_module` " |
|
|
f"function to apply FasterCache to this layer." |
|
|
) |
|
|
return |
|
|
|
|
|
logger.debug(f"Enabling FasterCache ({block_type}) for layer: {name}") |
|
|
hook = FasterCacheBlockHook( |
|
|
block_skip_range, |
|
|
timestep_skip_range, |
|
|
config.is_guidance_distilled, |
|
|
config.attention_weight_callback, |
|
|
config.current_timestep_callback, |
|
|
) |
|
|
registry = HookRegistry.check_if_exists_or_initialize(module) |
|
|
registry.register_hook(hook, _FASTER_CACHE_BLOCK_HOOK) |
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad() |
|
|
def _split_low_high_freq(x):
    r"""Split the 2D FFT of each frame into low- and high-frequency components using a centered circular mask
    with radius one-fifth of the smaller spatial dimension."""
    fft = torch.fft.fft2(x)
    fft_shifted = torch.fft.fftshift(fft)
    height, width = x.shape[-2:]
    radius = min(height, width) // 5

    y_grid, x_grid = torch.meshgrid(
        torch.arange(height, device=x.device), torch.arange(width, device=x.device), indexing="ij"
    )
    center_x, center_y = width // 2, height // 2
    mask = (x_grid - center_x) ** 2 + (y_grid - center_y) ** 2 <= radius**2

    low_freq_mask = mask.unsqueeze(0).unsqueeze(0)
    high_freq_mask = ~low_freq_mask

    low_freq_fft = fft_shifted * low_freq_mask
    high_freq_fft = fft_shifted * high_freq_mask

    return low_freq_fft, high_freq_fft
|
|
|