Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass | |
| from typing import List, Optional, Tuple, Union | |
| import torch | |
| from ..utils import get_logger | |
| from ..utils.torch_utils import unwrap_module | |
| from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS | |
| from ._helpers import TransformerBlockRegistry | |
| from .hooks import BaseState, HookRegistry, ModelHook, StateManager | |
# Logger for this module, obtained through the package's `get_logger` helper.
logger = get_logger(__name__)  # pylint: disable=invalid-name

# Names under which the MagCache hooks are stored in each block's HookRegistry. The
# `_apply_mag_cache_*` helpers look these names up to remove a previously registered hook
# before registering a new one.
_MAG_CACHE_LEADER_BLOCK_HOOK = "mag_cache_leader_block_hook"
_MAG_CACHE_BLOCK_HOOK = "mag_cache_block_hook"
# Default magnitude ratios for Flux models (Dev/Schnell), provided for convenience.
# They are never applied implicitly: pass them to the config yourself when using Flux,
# e.g. `MagCacheConfig(mag_ratios=FLUX_MAG_RATIOS)`.
# Reference: https://github.com/Zehong-Ma/MagCache
#
# The tensor holds 28 values, one per step of a 28-step schedule. The leading 1.0 matches the
# neutral ratio that calibration records for the first step. When `num_inference_steps` differs
# from this length, MagCacheConfig resamples the values with `nearest_interp`.
FLUX_MAG_RATIOS = torch.tensor(
    [
        1.0,
        1.21094, 1.11719, 1.07812, 1.0625, 1.03906, 1.03125, 1.03906,
        1.02344, 1.03125, 1.02344, 0.98047, 1.01562, 1.00781, 1.0,
        1.00781, 1.0, 1.00781, 1.0, 1.0, 0.99609, 0.99609,
        0.98047, 0.98828, 0.96484, 0.95703, 0.93359, 0.89062,
    ]
)
def nearest_interp(src_array: torch.Tensor, target_length: int) -> torch.Tensor:
    """
    Resample a sequence to `target_length` entries using nearest-neighbour interpolation.

    Output position `i` maps to source position `i * (len(src) - 1) / (target_length - 1)`,
    rounded to the nearest index with `torch.round`. The first and last outputs therefore
    always take the first and last source entries.

    Args:
        src_array (`torch.Tensor` or sequence of floats):
            Values to resample. Must contain at least one element. Non-tensor sequences are
            converted with `torch.as_tensor`.
        target_length (`int`):
            Number of entries in the result. `0` yields an empty tensor, and `1` yields only
            the last source entry.

    Returns:
        `torch.Tensor`: `target_length` entries gathered from `src_array`, on the same device.

    Raises:
        ValueError: If `src_array` is empty or `target_length` is negative.
    """
    if not torch.is_tensor(src_array):
        src_array = torch.as_tensor(src_array)
    src_length = len(src_array)
    if src_length == 0:
        # Without this check, an empty source raises an opaque IndexError, or for
        # target_length == 1 silently returns an empty tensor of the wrong length.
        raise ValueError("`src_array` must contain at least one element to be interpolated.")
    if target_length < 0:
        raise ValueError(f"`target_length` must be non-negative, got {target_length}.")
    if target_length == 0:
        return src_array[:0]
    if target_length == 1:
        # The spacing formula below would divide by zero; keep the final entry.
        return src_array[-1:]
    scale = (src_length - 1) / (target_length - 1)
    grid = torch.arange(target_length, device=src_array.device, dtype=torch.float32)
    mapped_indices = torch.round(grid * scale).long()
    return src_array[mapped_indices]
@dataclass
class MagCacheConfig:
    r"""
    Configuration for [MagCache](https://github.com/Zehong-Ma/MagCache).

    Args:
        threshold (`float`, defaults to `0.06`):
            The threshold for the accumulated error. A step is only skipped while the accumulated error is at or
            below this threshold (and the other skip conditions hold). A higher threshold allows more aggressive
            skipping (faster) but may degrade quality.
        max_skip_steps (`int`, defaults to `3`):
            The maximum number of consecutive steps that can be skipped (K in the paper).
        retention_ratio (`float`, defaults to `0.2`):
            The fraction of initial steps during which skipping is disabled to ensure stability. For example, if
            `num_inference_steps` is 28 and `retention_ratio` is 0.2, the first 6 steps will never be skipped.
        num_inference_steps (`int`, defaults to `28`):
            The number of inference steps used in the pipeline. This is required to interpolate `mag_ratios` correctly
            and to detect the end of an inference run. Must be at least 1.
        mag_ratios (`torch.Tensor`, *optional*):
            The pre-computed magnitude ratios for the model. These are checkpoint-dependent. If not provided, you must
            set `calibrate=True` to calculate them for your specific model. For Flux models, you can use
            `diffusers.hooks.mag_cache.FLUX_MAG_RATIOS`. Lists are converted to tensors, and ratios whose length
            differs from `num_inference_steps` are resampled with nearest-neighbour interpolation.
        calibrate (`bool`, defaults to `False`):
            If True, enables calibration mode. In this mode, no blocks are skipped. Instead, the hook calculates the
            magnitude ratios for the current run and logs them at the end. Use this to obtain `mag_ratios` for new
            models or schedulers.

    Raises:
        ValueError: If `num_inference_steps` is smaller than 1, or if `mag_ratios` is missing while `calibrate` is
            False.
    """

    threshold: float = 0.06
    max_skip_steps: int = 3
    retention_ratio: float = 0.2
    num_inference_steps: int = 28
    mag_ratios: Optional[Union[torch.Tensor, List[float]]] = None
    calibrate: bool = False

    def __post_init__(self):
        # The step counter wraps around after `num_inference_steps` steps, so it must be positive.
        if self.num_inference_steps < 1:
            raise ValueError(
                f"`num_inference_steps` must be a positive integer, got {self.num_inference_steps}."
            )
        # User MUST provide ratios OR enable calibration.
        if self.mag_ratios is None and not self.calibrate:
            raise ValueError(
                " `mag_ratios` must be provided for MagCache inference because these ratios are model-dependent.\n"
                "To get them for your model:\n"
                "1. Initialize `MagCacheConfig(calibrate=True, ...)`\n"
                "2. Run inference on your model once.\n"
                "3. Copy the printed ratios array and pass it to `mag_ratios` in the config.\n"
                "For Flux models, you can import `FLUX_MAG_RATIOS` from `diffusers.hooks.mag_cache`."
            )
        if not self.calibrate and self.mag_ratios is not None:
            if not torch.is_tensor(self.mag_ratios):
                self.mag_ratios = torch.tensor(self.mag_ratios)
            # The head hook indexes the ratios by step index, so align their length with the schedule.
            if len(self.mag_ratios) != self.num_inference_steps:
                logger.debug(
                    f"Interpolating mag_ratios from length {len(self.mag_ratios)} to {self.num_inference_steps}"
                )
                self.mag_ratios = nearest_interp(self.mag_ratios, self.num_inference_steps)
class MagCacheState(BaseState):
    """
    Mutable MagCache bookkeeping shared by the head hook and the block hooks of one model.

    The head hook writes `head_block_input` and `should_compute` on every forward pass. The tail
    block hook writes `previous_residual`, `calibration_ratios` and `step_index`. `reset()` returns
    the object to its freshly constructed state.
    """

    def __init__(self) -> None:
        super().__init__()
        # Cache for the residual (output - input) from the *previous* computed timestep.
        self.previous_residual: Optional[torch.Tensor] = None
        # Hidden states fed to the head block in the current forward pass.
        self.head_block_input: Optional[Union[torch.Tensor, Tuple[torch.Tensor, ...]]] = None
        # Head hook's decision for the current forward pass (False means skip the blocks).
        self.should_compute: bool = True
        # MagCache accumulators
        self.accumulated_ratio: float = 1.0
        self.accumulated_err: float = 0.0
        self.accumulated_steps: int = 0
        # Current step counter (timestep index)
        self.step_index: int = 0
        # Calibration storage
        self.calibration_ratios: List[float] = []

    def reset(self):
        """Restore every field to its initial value, dropping references to cached tensors."""
        self.previous_residual = None
        # Previously left untouched, which kept the last activation tensor alive after a reset.
        self.head_block_input = None
        self.should_compute = True
        self.accumulated_ratio = 1.0
        self.accumulated_err = 0.0
        self.accumulated_steps = 0
        self.step_index = 0
        self.calibration_ratios = []
class MagCacheHeadHook(ModelHook):
    """
    Hook registered on the first (head) transformer block.

    On every forward pass it stores the head block's input hidden states in the shared
    `MagCacheState`, decides whether the current step must be computed, and publishes that
    decision as `state.should_compute` for the other block hooks. On a skipped step the head block
    is not run. The hook returns the input hidden states plus the residual cached from the last
    computed step, packed in the block's usual output layout.
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager, config: MagCacheConfig):
        # Initialize base-class attributes, as MagCacheBlockHook does.
        super().__init__()
        self.state_manager = state_manager
        self.config = config
        # Block metadata looked up from TransformerBlockRegistry in `initialize_hook`.
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if self.state_manager._current_context is None:
            self.state_manager.set_context("inference")

        arg_name = self._metadata.hidden_states_argument_name
        hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)

        state: MagCacheState = self.state_manager.get_state()
        state.head_block_input = hidden_states

        # Calibration must observe every step, so it never skips (and leaves the accumulators alone).
        should_compute = True if self.config.calibrate else self._should_compute_step(state)
        state.should_compute = should_compute

        if should_compute:
            # Compute original forward
            return self.fn_ref.original_forward(*args, **kwargs)

        logger.debug(f"MagCache: Skipping step {state.step_index}")
        # Apply MagCache: Output = Input + Previous Residual
        output = self._apply_cached_residual(hidden_states, state.previous_residual)
        return self._pack_skip_output(output, args, kwargs)

    def _should_compute_step(self, state: MagCacheState) -> bool:
        """Update the MagCache accumulators for the current step and return True if it must be computed."""
        current_step = state.step_index
        if current_step >= len(self.config.mag_ratios):
            current_scale = 1.0
        else:
            # float() keeps the accumulators plain Python floats instead of 0-d tensors.
            current_scale = float(self.config.mag_ratios[current_step])

        retention_step = int(self.config.retention_ratio * self.config.num_inference_steps + 0.5)
        if current_step < retention_step:
            # Retention (warm-up) steps are always computed.
            return True

        state.accumulated_ratio *= current_scale
        state.accumulated_steps += 1
        state.accumulated_err += abs(1.0 - state.accumulated_ratio)

        if (
            state.previous_residual is not None
            and state.accumulated_err <= self.config.threshold
            and state.accumulated_steps <= self.config.max_skip_steps
        ):
            return False

        # This step will be computed, producing a fresh residual, so error accumulation restarts.
        state.accumulated_ratio = 1.0
        state.accumulated_steps = 0
        state.accumulated_err = 0.0
        return True

    def _apply_cached_residual(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
        """Return `hidden_states + residual`, aligning shapes where possible, else `hidden_states` unchanged."""
        if residual.device != hidden_states.device:
            residual = residual.to(hidden_states.device)

        if residual.shape == hidden_states.shape:
            return hidden_states + residual

        if (
            hidden_states.ndim == 3
            and residual.ndim == 3
            and hidden_states.shape[0] == residual.shape[0]
            and hidden_states.shape[2] == residual.shape[2]
            and hidden_states.shape[1] > residual.shape[1]
        ):
            # Assuming concatenation where image part is at the end (standard in Flux/SD3):
            # the residual only covers the trailing tokens of the input.
            offset = hidden_states.shape[1] - residual.shape[1]
            output = hidden_states.clone()
            output[:, offset:, :] = output[:, offset:, :] + residual
            return output

        logger.warning(
            f"MagCache: Dimension mismatch. Input {hidden_states.shape}, Residual {residual.shape}. "
            "Cannot apply residual safely. Returning input without residual."
        )
        return hidden_states

    def _pack_skip_output(self, output: torch.Tensor, args, kwargs):
        """Arrange the skipped-step output in the same positions the block's forward would return."""
        if self._metadata.return_encoder_hidden_states_index is None:
            return output
        # The block also returns encoder hidden states; pass the incoming ones through unchanged.
        original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
            "encoder_hidden_states", args, kwargs
        )
        max_idx = max(
            self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
        )
        ret_list = [None] * (max_idx + 1)
        ret_list[self._metadata.return_hidden_states_index] = output
        ret_list[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
        return tuple(ret_list)

    def reset_state(self, module):
        self.state_manager.reset()
        return module
class MagCacheBlockHook(ModelHook):
    """
    Hook registered on every transformer block after the head block.

    When the head hook has decided to skip the current step (`state.should_compute` is False), the
    block is not run and its incoming hidden states are passed straight through. The tail block
    (`is_tail=True`) also:

    - computes the residual between its output and the head block's input,
    - caches that residual for future skipped steps,
    - records a calibration ratio when `config.calibrate` is set,
    - advances the shared step counter.
    """

    def __init__(self, state_manager: StateManager, is_tail: bool = False, config: Optional[MagCacheConfig] = None):
        super().__init__()
        self.state_manager = state_manager
        self.is_tail = is_tail
        # Read only by the tail hook (calibration flag, num_inference_steps).
        self.config = config
        # Block metadata looked up from TransformerBlockRegistry in `initialize_hook`.
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        if self.state_manager._current_context is None:
            self.state_manager.set_context("inference")
        state: MagCacheState = self.state_manager.get_state()

        if not state.should_compute:
            return self._passthrough(state, args, kwargs)

        output = self.fn_ref.original_forward(*args, **kwargs)
        if self.is_tail:
            self._record_tail_output(state, output)
        return output

    def _passthrough(self, state: MagCacheState, args, kwargs):
        """Return the block's inputs unchanged, in the block's output layout, for a skipped step."""
        arg_name = self._metadata.hidden_states_argument_name
        hidden_states = self._metadata._get_parameter_from_args_kwargs(arg_name, args, kwargs)
        if self.is_tail:
            # Still need to advance step index even if we skip
            self._advance_step(state)
        if self._metadata.return_encoder_hidden_states_index is None:
            return hidden_states
        encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
            "encoder_hidden_states", args, kwargs
        )
        max_idx = max(
            self._metadata.return_hidden_states_index, self._metadata.return_encoder_hidden_states_index
        )
        ret_list = [None] * (max_idx + 1)
        ret_list[self._metadata.return_hidden_states_index] = hidden_states
        ret_list[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
        return tuple(ret_list)

    def _record_tail_output(self, state: MagCacheState, output) -> None:
        """Cache the residual of a computed step, record calibration data, and advance the step counter."""
        if isinstance(output, tuple):
            out_hidden = output[self._metadata.return_hidden_states_index]
        else:
            out_hidden = output

        in_hidden = state.head_block_input
        if in_hidden is not None:
            residual = self._compute_residual(out_hidden, in_hidden)
            if self.config.calibrate:
                self._perform_calibration_step(state, residual)
            state.previous_residual = residual

        # A denoising step completed either way; advancing keeps step_index aligned with mag_ratios
        # (previously a missing head input returned early and left the counter behind).
        self._advance_step(state)

    @staticmethod
    def _compute_residual(out_hidden: torch.Tensor, in_hidden: torch.Tensor) -> torch.Tensor:
        """Return the tail output minus the head input, aligning trailing tokens when lengths differ."""
        if out_hidden.shape == in_hidden.shape:
            return out_hidden - in_hidden
        if (
            out_hidden.ndim == 3
            and in_hidden.ndim == 3
            and out_hidden.shape[0] == in_hidden.shape[0]
            and out_hidden.shape[2] == in_hidden.shape[2]
            and in_hidden.shape[1] > out_hidden.shape[1]
        ):
            # Same layout the head hook assumes when it re-applies the residual: the output covers
            # the trailing tokens of the input. The old code subtracted the full input here, which
            # fails to broadcast when the sequence lengths differ.
            offset = in_hidden.shape[1] - out_hidden.shape[1]
            return out_hidden - in_hidden[:, offset:, :]
        # Fallback for shapes that cannot be aligned: cache the raw output.
        return out_hidden

    def _perform_calibration_step(self, state: MagCacheState, current_residual: torch.Tensor):
        """Append the mean ratio of current-to-previous residual norms (last dimension); 1.0 on the first step."""
        if state.previous_residual is None:
            # First step has no previous residual to compare against.
            # log 1.0 as a neutral starting point.
            ratio = 1.0
        else:
            # MagCache Calibration Formula: mean(norm(curr) / norm(prev))
            # norm(dim=-1) gives magnitude of each token vector
            curr_norm = torch.linalg.norm(current_residual.float(), dim=-1)
            prev_norm = torch.linalg.norm(state.previous_residual.float(), dim=-1)
            # Avoid division by zero
            ratio = (curr_norm / (prev_norm + 1e-8)).mean().item()
        state.calibration_ratios.append(ratio)

    def _advance_step(self, state: MagCacheState):
        """Increment step_index; after num_inference_steps steps, report calibration results and reset the run state."""
        state.step_index += 1
        if state.step_index >= self.config.num_inference_steps:
            # End of inference loop
            if self.config.calibrate:
                print("\n[MagCache] Calibration Complete. Copy these values to MagCacheConfig(mag_ratios=...):")
                print(f"{state.calibration_ratios}\n")
                logger.info(f"MagCache Calibration Results: {state.calibration_ratios}")
            # Reset state
            state.step_index = 0
            state.accumulated_ratio = 1.0
            state.accumulated_steps = 0
            state.accumulated_err = 0.0
            state.previous_residual = None
            state.calibration_ratios = []
def apply_mag_cache(module: torch.nn.Module, config: MagCacheConfig) -> None:
    """
    Attach MagCache hooks to the transformer blocks of `module` (typically a Transformer).

    Blocks are gathered in order from every direct child whose name appears in
    `_ALL_TRANSFORMER_BLOCK_IDENTIFIERS` and which is a `torch.nn.ModuleList`. The first block receives
    the head hook, the last block a tail block hook, and every block in between a plain block hook.
    All hooks share one `StateManager`. If no blocks are found, a warning is logged and nothing is
    attached.

    Args:
        module (`torch.nn.Module`):
            The module to apply MagCache to.
        config (`MagCacheConfig`):
            The configuration for MagCache.
    """
    # Initialize registry on the root module so the Pipeline can set context.
    HookRegistry.check_if_exists_or_initialize(module)
    state_manager = StateManager(MagCacheState, (), {})

    blocks = [
        (f"{child_name}.{index}", block)
        for child_name, child in module.named_children()
        if child_name in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS and isinstance(child, torch.nn.ModuleList)
        for index, block in enumerate(child)
    ]

    if not blocks:
        logger.warning("MagCache: No transformer blocks found to apply hooks.")
        return

    if len(blocks) == 1:
        # Single-block model: the same block acts as head and tail. The tail hook is registered
        # before the head hook.
        # NOTE(review): if later-registered hooks wrap earlier ones, a skipping head hook returns
        # before the tail hook runs, so step_index may not advance on skipped steps. Confirm
        # against HookRegistry.
        only_name, only_block = blocks[0]
        logger.info(f"MagCache: Applying Head+Tail Hooks to single block '{only_name}'")
        _apply_mag_cache_block_hook(only_block, state_manager, config, is_tail=True)
        _apply_mag_cache_head_hook(only_block, state_manager, config)
        return

    (head_name, head_block), *middle_blocks, (tail_name, tail_block) = blocks

    logger.info(f"MagCache: Applying Head Hook to {head_name}")
    _apply_mag_cache_head_hook(head_block, state_manager, config)

    for _, middle_block in middle_blocks:
        _apply_mag_cache_block_hook(middle_block, state_manager, config)

    logger.info(f"MagCache: Applying Tail Hook to {tail_name}")
    _apply_mag_cache_block_hook(tail_block, state_manager, config, is_tail=True)
def _apply_mag_cache_head_hook(block: torch.nn.Module, state_manager: StateManager, config: MagCacheConfig) -> None:
    """Register a new `MagCacheHeadHook` on `block`, first removing any MagCache head hook already there."""
    registry = HookRegistry.check_if_exists_or_initialize(block)
    # Drop a previously applied head hook so MagCache can be re-applied (e.g. switching modes).
    existing_hook = registry.get_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
    if existing_hook is not None:
        registry.remove_hook(_MAG_CACHE_LEADER_BLOCK_HOOK)
    registry.register_hook(MagCacheHeadHook(state_manager, config), _MAG_CACHE_LEADER_BLOCK_HOOK)
def _apply_mag_cache_block_hook(
    block: torch.nn.Module,
    state_manager: StateManager,
    config: MagCacheConfig,
    is_tail: bool = False,
) -> None:
    """Register a new `MagCacheBlockHook` (tail variant when `is_tail`) on `block`, replacing any existing one."""
    registry = HookRegistry.check_if_exists_or_initialize(block)
    # Drop a previously applied block hook so MagCache can be re-applied.
    existing_hook = registry.get_hook(_MAG_CACHE_BLOCK_HOOK)
    if existing_hook is not None:
        registry.remove_hook(_MAG_CACHE_BLOCK_HOOK)
    new_hook = MagCacheBlockHook(state_manager, is_tail=is_tail, config=config)
    registry.register_hook(new_hook, _MAG_CACHE_BLOCK_HOOK)