import math
import re
from dataclasses import dataclass

import torch
import torch.nn as nn

from ..utils import logging
from .hooks import HookRegistry, ModelHook, StateManager


logger = logging.get_logger(__name__)

_TAYLORSEER_CACHE_HOOK = "taylorseer_cache"
_SPATIAL_ATTENTION_BLOCK_IDENTIFIERS = (
    "^blocks.*attn",
    "^transformer_blocks.*attn",
    "^single_transformer_blocks.*attn",
)
_TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS = ("^temporal_transformer_blocks.*attn",)
_TRANSFORMER_BLOCK_IDENTIFIERS = _SPATIAL_ATTENTION_BLOCK_IDENTIFIERS + _TEMPORAL_ATTENTION_BLOCK_IDENTIFIERS
_BLOCK_IDENTIFIERS = ("^[^.]*block[^.]*\\.[^.]+$",)
_PROJ_OUT_IDENTIFIERS = ("^proj_out$",)


@dataclass
class TaylorSeerCacheConfig:
    """
    Configuration for the TaylorSeer cache.

    See: https://huggingface.co/papers/2503.06923

    Attributes:
        cache_interval (`int`, defaults to `5`):
            The interval between full computation steps. After a full computation, the cached (predicted) outputs
            are reused for this many subsequent denoising steps before refreshing with a new full forward pass.
        disable_cache_before_step (`int`, defaults to `3`):
            The denoising step index before which caching is disabled: full computation is performed for the
            initial steps (0 to `disable_cache_before_step - 1`) to gather data for the Taylor series
            approximation. During these steps, Taylor factors are updated, but cached predictions are not
            applied. Caching begins at this step.
        disable_cache_after_step (`int`, *optional*, defaults to `None`):
            The denoising step index after which caching is disabled. If set, for steps >= this value, all
            modules run full computations without predictions or state updates, ensuring accuracy in later
            stages if needed.
        max_order (`int`, defaults to `1`):
            The highest order of the Taylor series expansion used to approximate module outputs. Higher orders
            provide better approximations but increase computation and memory usage.
        taylor_factors_dtype (`torch.dtype`, defaults to `torch.bfloat16`):
            Data type used for storing and computing Taylor series factors. Lower precision reduces memory but
            may affect stability; higher precision improves accuracy at the cost of more memory.
        skip_predict_identifiers (`list[str]`, *optional*, defaults to `None`):
            Regex patterns (matched with `re.fullmatch`) for module names to place in "skip" mode. In this mode,
            the module computes fully during initial or refresh steps but returns a zero tensor (matching the
            recorded output shape) during prediction steps, skipping its computation cheaply.
        cache_identifiers (`list[str]`, *optional*, defaults to `None`):
            Regex patterns (matched with `re.fullmatch`) for module names to place in Taylor-series caching mode,
            where outputs are approximated and cached for reuse.
        use_lite_mode (`bool`, *optional*, defaults to `False`):
            Enables a lightweight TaylorSeer variant that minimizes memory usage by applying predefined patterns
            for skipping and caching (skipping transformer blocks and caching the output projection). This
            overrides any custom `skip_predict_identifiers` or `cache_identifiers`.

    Notes:
        - Patterns are matched using `re.fullmatch` on the module name.
        - If `skip_predict_identifiers` or `cache_identifiers` are provided, only matching modules are hooked.
        - If neither is provided, all attention-like modules are hooked by default.
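
    For example, `^transformer_blocks.*attn` fullmatches `transformer_blocks.0.attn` but not
    `transformer_blocks.0.attn.to_q`, because `re.fullmatch` must consume the entire module name. A minimal
    sketch of custom patterns (the `ff` module name is illustrative and depends on the model):

    ```py
    config = TaylorSeerCacheConfig(
        cache_interval=5,
        cache_identifiers=["^transformer_blocks.*attn"],  # approximate attention outputs via Taylor series
        skip_predict_identifiers=["^transformer_blocks.*ff$"],  # zero out these outputs on prediction steps
    )
    ```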
    Example of skip ("inactive") and cache ("active") usage:

    ```py
    def forward(x):
        x = self.module1(x)  # skip module: returns a zeros tensor with the shape recorded during full compute
        x = self.module2(x)  # cache module: returns the Taylor-series prediction instead of recomputing
        return x
    ```
    """

    cache_interval: int = 5
    disable_cache_before_step: int = 3
    disable_cache_after_step: int | None = None
    max_order: int = 1
    taylor_factors_dtype: torch.dtype | None = torch.bfloat16
    skip_predict_identifiers: list[str] | None = None
    cache_identifiers: list[str] | None = None
    use_lite_mode: bool = False

    def __repr__(self) -> str:
        return (
            "TaylorSeerCacheConfig("
            f"cache_interval={self.cache_interval}, "
            f"disable_cache_before_step={self.disable_cache_before_step}, "
            f"disable_cache_after_step={self.disable_cache_after_step}, "
            f"max_order={self.max_order}, "
            f"taylor_factors_dtype={self.taylor_factors_dtype}, "
            f"skip_predict_identifiers={self.skip_predict_identifiers}, "
            f"cache_identifiers={self.cache_identifiers}, "
            f"use_lite_mode={self.use_lite_mode})"
        )


class TaylorSeerState:
    def __init__(
        self,
        taylor_factors_dtype: torch.dtype | None = torch.bfloat16,
        max_order: int = 1,
        is_inactive: bool = False,
    ):
        self.taylor_factors_dtype = taylor_factors_dtype
        self.max_order = max_order
        self.is_inactive = is_inactive
        self.module_dtypes: tuple[torch.dtype, ...] = ()
        self.last_update_step: int | None = None
        self.taylor_factors: dict[int, dict[int, torch.Tensor]] = {}
        self.inactive_shapes: tuple[tuple[int, ...], ...] | None = None
        self.device: torch.device | None = None
        self.current_step: int = -1

    def reset(self) -> None:
        self.current_step = -1
        self.last_update_step = None
        self.taylor_factors = {}
        self.inactive_shapes = None
        self.device = None

    def update(
        self,
        outputs: tuple[torch.Tensor, ...],
    ) -> None:
        self.module_dtypes = tuple(output.dtype for output in outputs)
        self.device = outputs[0].device
        if self.is_inactive:
            self.inactive_shapes = tuple(output.shape for output in outputs)
        else:
            for i, features in enumerate(outputs):
                new_factors: dict[int, torch.Tensor] = {0: features}
                is_first_update = self.last_update_step is None
                if not is_first_update:
                    delta_step = self.current_step - self.last_update_step
                    if delta_step == 0:
                        raise ValueError("Delta step cannot be zero for TaylorSeer update.")
                    # Recursive divided differences up to max_order
                    prev_factors = self.taylor_factors.get(i, {})
                    for j in range(self.max_order):
                        prev = prev_factors.get(j)
                        if prev is None:
                            break
                        new_factors[j + 1] = (new_factors[j] - prev.to(features.dtype)) / delta_step
                self.taylor_factors[i] = {
                    order: factor.to(self.taylor_factors_dtype) for order, factor in new_factors.items()
                }
        self.last_update_step = self.current_step

    @torch.compiler.disable
    def predict(self) -> list[torch.Tensor]:
        if self.last_update_step is None:
            raise ValueError("Cannot predict without prior initialization/update.")
        step_offset = self.current_step - self.last_update_step
        outputs = []
        if self.is_inactive:
            if self.inactive_shapes is None:
                raise ValueError("Inactive shapes not set during prediction.")
            for i in range(len(self.module_dtypes)):
                outputs.append(
                    torch.zeros(
                        self.inactive_shapes[i],
                        dtype=self.module_dtypes[i],
                        device=self.device,
                    )
                )
        else:
            if not self.taylor_factors:
                raise ValueError("Taylor factors empty during prediction.")
            num_outputs = len(self.taylor_factors)
            num_orders = len(self.taylor_factors[0])
            for i in range(num_outputs):
                output_dtype = self.module_dtypes[i]
                taylor_factors = self.taylor_factors[i]
                output = torch.zeros_like(taylor_factors[0], dtype=output_dtype)
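                # Evaluate the truncated Taylor expansion around the last full-compute step:
                #   prediction = sum_k factors[k] * step_offset**k / k!
                # where factors[k] are the divided-difference factors stored by `update`
                # and step_offset = current_step - last_update_step.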
                for order in range(num_orders):
                    coeff = (step_offset**order) / math.factorial(order)
                    factor = taylor_factors[order]
                    output = output + factor.to(output_dtype) * coeff
                outputs.append(output)
        return outputs


class TaylorSeerCacheHook(ModelHook):
    _is_stateful = True

    def __init__(
        self,
        cache_interval: int,
        disable_cache_before_step: int,
        taylor_factors_dtype: torch.dtype,
        state_manager: StateManager,
        disable_cache_after_step: int | None = None,
    ):
        super().__init__()
        self.cache_interval = cache_interval
        self.disable_cache_before_step = disable_cache_before_step
        self.disable_cache_after_step = disable_cache_after_step
        self.taylor_factors_dtype = taylor_factors_dtype
        self.state_manager = state_manager

    def initialize_hook(self, module: torch.nn.Module):
        return module

    def reset_state(self, module: torch.nn.Module) -> None:
        """Reset state between sampling runs."""
        self.state_manager.reset()

    @torch.compiler.disable
    def _measure_should_compute(self) -> tuple[bool, TaylorSeerState]:
        state: TaylorSeerState = self.state_manager.get_state()
        state.current_step += 1
        current_step = state.current_step
        is_warmup_phase = current_step < self.disable_cache_before_step
        is_compute_interval = (current_step - self.disable_cache_before_step - 1) % self.cache_interval == 0
        is_cooldown_phase = (
            self.disable_cache_after_step is not None and current_step >= self.disable_cache_after_step
        )
        should_compute = is_warmup_phase or is_compute_interval or is_cooldown_phase
        return should_compute, state

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        should_compute, state = self._measure_should_compute()
        if should_compute:
            outputs = self.fn_ref.original_forward(*args, **kwargs)
            wrapped_outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
            state.update(wrapped_outputs)
            return outputs
        outputs_list = state.predict()
        return outputs_list[0] if len(outputs_list) == 1 else tuple(outputs_list)


def _resolve_patterns(config: TaylorSeerCacheConfig) -> tuple[list[str], list[str]]:
    """
    Resolve the effective skip ("inactive") and cache ("active") pattern lists from the config.
    """
    inactive_patterns = config.skip_predict_identifiers or []
    active_patterns = config.cache_identifiers or []
    return inactive_patterns, active_patterns


def apply_taylorseer_cache(module: torch.nn.Module, config: TaylorSeerCacheConfig):
    """
    Applies the TaylorSeer cache to a given model (typically the transformer / UNet).

    This function hooks selected modules in the model to enable caching or skipping based on the provided
    configuration, reducing redundant computations in diffusion denoising loops.

    Args:
        module (`torch.nn.Module`):
            The model subtree to apply the hooks to.
        config (`TaylorSeerCacheConfig`):
            Configuration for the cache.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import FluxPipeline, TaylorSeerCacheConfig

    >>> pipe = FluxPipeline.from_pretrained(
    ...     "black-forest-labs/FLUX.1-dev",
    ...     torch_dtype=torch.bfloat16,
    ... )
    >>> pipe.to("cuda")

    >>> config = TaylorSeerCacheConfig(
    ...     cache_interval=5,
    ...     max_order=1,
    ...     disable_cache_before_step=3,
    ...     taylor_factors_dtype=torch.float32,
    ... )
    >>> pipe.transformer.enable_cache(config)
    ```
    """
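    # Pattern resolution order (summarizing the logic below): user-provided identifiers take effect when
    # given; when no cache patterns are given, attention-like transformer blocks are cached by default;
    # lite mode overrides both, skipping whole blocks and Taylor-caching only the final output projection.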
    inactive_patterns, active_patterns = _resolve_patterns(config)
    active_patterns = active_patterns or _TRANSFORMER_BLOCK_IDENTIFIERS

    if config.use_lite_mode:
        logger.info("Using TaylorSeer Lite variant for cache.")
        active_patterns = _PROJ_OUT_IDENTIFIERS
        inactive_patterns = _BLOCK_IDENTIFIERS
        if config.skip_predict_identifiers or config.cache_identifiers:
            logger.warning("Lite mode overrides user-provided `skip_predict_identifiers` and `cache_identifiers`.")

    for name, submodule in module.named_modules():
        matches_inactive = any(re.fullmatch(pattern, name) for pattern in inactive_patterns)
        matches_active = any(re.fullmatch(pattern, name) for pattern in active_patterns)
        if not (matches_inactive or matches_active):
            continue
        _apply_taylorseer_cache_hook(
            module=submodule,
            config=config,
            is_inactive=matches_inactive,
        )


def _apply_taylorseer_cache_hook(
    module: nn.Module,
    config: TaylorSeerCacheConfig,
    is_inactive: bool,
):
    """
    Registers the TaylorSeer hook on the given `nn.Module`.

    Args:
        module: The `nn.Module` to be hooked.
        config: Cache configuration.
        is_inactive: Whether this module should operate in "skip" (inactive) mode.
    """
    state_manager = StateManager(
        TaylorSeerState,
        init_kwargs={
            "taylor_factors_dtype": config.taylor_factors_dtype,
            "max_order": config.max_order,
            "is_inactive": is_inactive,
        },
    )
    registry = HookRegistry.check_if_exists_or_initialize(module)
    hook = TaylorSeerCacheHook(
        cache_interval=config.cache_interval,
        disable_cache_before_step=config.disable_cache_before_step,
        taylor_factors_dtype=config.taylor_factors_dtype,
        disable_cache_after_step=config.disable_cache_after_step,
        state_manager=state_manager,
    )
    registry.register_hook(hook, _TAYLORSEER_CACHE_HOOK)
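

# Note: with the default TaylorSeerCacheConfig (cache_interval=5, disable_cache_before_step=3),
# `_measure_should_compute` yields the following schedule: steps 0-2 run full computation (warmup),
# step 3 returns the first prediction, step 4 refreshes the Taylor factors since (4 - 3 - 1) % 5 == 0,
# steps 5-8 are predicted, step 9 refreshes again, and so on every `cache_interval` steps.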