from dataclasses import dataclass

import torch

from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager


logger = get_logger(__name__)

_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
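
# Two hook classes cooperate to implement the cache: `FBCHeadBlockHook` wraps the first transformer block
# and decides, once per forward pass, whether the remaining blocks need to run; `FBCBlockHook` wraps every
# other block and either runs it normally or returns its inputs unchanged when the head hook chose to skip.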


@dataclass
class FirstBlockCacheConfig:
    r"""
    Configuration for [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).

    Args:
        threshold (`float`, defaults to `0.05`):
            The threshold that determines whether a forward pass through all layers of the model is required. A
            higher threshold skips the remaining layers more often, which speeds up inference but may degrade
            generation quality; a lower threshold skips less often and may not yield a significant speedup. The
            threshold is compared against the relative absmean difference between the current and previously
            cached residuals of the first transformer block: if the difference is below the threshold, the
            forward pass through the remaining blocks is skipped.
    """

    threshold: float = 0.05
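
# A worked example of the decision rule (hypothetical numbers, for illustration only): suppose the cached
# first-block residual has absmean 1.0 and the new residual differs from it by absmean 0.03. The relative
# difference is 0.03 / 1.0 = 0.03, which is below the default threshold of 0.05, so the remaining blocks
# are skipped and the cached tail residuals are reused for this step.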


class FBCSharedBlockState(BaseState):
    def __init__(self) -> None:
        super().__init__()

        # Output and residual of the first (head) transformer block, cached across steps.
        self.head_block_output: torch.Tensor | tuple[torch.Tensor, ...] | None = None
        self.head_block_residual: torch.Tensor | None = None
        # Residual between the tail block's output and the head block's output, cached on full passes.
        self.tail_block_residuals: torch.Tensor | tuple[torch.Tensor, ...] | None = None
        self.should_compute: bool = True

    def reset(self):
        # Only the per-step skip decision and the tail residuals are cleared; the head block entries
        # are overwritten whenever a full forward pass runs.
        self.tail_block_residuals = None
        self.should_compute = True
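
# A single `FBCSharedBlockState` instance is shared by all hooks below: `apply_first_block_cache` creates
# one `StateManager` and passes it to the head hook and to every other block hook, so the head block's
# decision and cached tensors are visible to the rest of the stack.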


class FBCHeadBlockHook(ModelHook):
    _is_stateful = True

    def __init__(self, state_manager: StateManager, threshold: float):
        super().__init__()
        self.state_manager = state_manager
        self.threshold = threshold
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)

        output = self.fn_ref.original_forward(*args, **kwargs)
        is_output_tuple = isinstance(output, tuple)

        # Residual introduced by the head block; its step-to-step change drives the skip decision.
        if is_output_tuple:
            hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
        else:
            hidden_states_residual = output - original_hidden_states

        shared_state: FBCSharedBlockState = self.state_manager.get_state()
        hidden_states = encoder_hidden_states = None
        should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
        shared_state.should_compute = should_compute

        if not should_compute:
            # Skip the remaining blocks: approximate their effect by adding the cached tail residuals
            # to the head block's output.
            if is_output_tuple:
                hidden_states = (
                    shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
                )
            else:
                hidden_states = shared_state.tail_block_residuals[0] + output

            if self._metadata.return_encoder_hidden_states_index is not None:
                assert is_output_tuple
                encoder_hidden_states = (
                    shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
                )

            if is_output_tuple:
                return_output = [None] * len(output)
                return_output[self._metadata.return_hidden_states_index] = hidden_states
                return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
                return_output = tuple(return_output)
            else:
                return_output = hidden_states
            output = return_output
        else:
            # Full pass: cache the head block's output and residual for comparison on later steps.
            if is_output_tuple:
                head_block_output = [None] * len(output)
                head_block_output[0] = output[self._metadata.return_hidden_states_index]
                head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
            else:
                head_block_output = output
            shared_state.head_block_output = head_block_output
            shared_state.head_block_residual = hidden_states_residual

        return output

    def reset_state(self, module):
        self.state_manager.reset()
        return module

    @torch.compiler.disable
    def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
        shared_state = self.state_manager.get_state()
        if shared_state.head_block_residual is None:
            # Nothing cached yet, so a full forward pass is required.
            return True
        prev_hidden_states_residual = shared_state.head_block_residual
        # Relative absmean change of the head block residual with respect to the cached one.
        absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
        prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
        diff = (absmean / prev_hidden_states_absmean).item()
        return diff > self.threshold


class FBCBlockHook(ModelHook):
    def __init__(self, state_manager: StateManager, is_tail: bool = False):
        super().__init__()
        self.state_manager = state_manager
        self.is_tail = is_tail
        self._metadata = None

    def initialize_hook(self, module):
        unwrapped_module = unwrap_module(module)
        self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
        return module

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
        original_encoder_hidden_states = None
        if self._metadata.return_encoder_hidden_states_index is not None:
            original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
                "encoder_hidden_states", args, kwargs
            )

        shared_state = self.state_manager.get_state()

        if shared_state.should_compute:
            output = self.fn_ref.original_forward(*args, **kwargs)
            if self.is_tail:
                # Cache the residual between the tail block's output and the head block's output, so
                # that skipped steps can approximate all remaining blocks with a single addition.
                hidden_states_residual = encoder_hidden_states_residual = None
                if isinstance(output, tuple):
                    hidden_states_residual = (
                        output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
                    )
                    encoder_hidden_states_residual = (
                        output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
                    )
                else:
                    hidden_states_residual = output - shared_state.head_block_output
                shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
            return output

        # Skipped step: pass the inputs through unchanged; the head hook has already applied the
        # cached tail residuals.
        if original_encoder_hidden_states is None:
            return_output = original_hidden_states
        else:
            return_output = [None, None]
            return_output[self._metadata.return_hidden_states_index] = original_hidden_states
            return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
            return_output = tuple(return_output)
        return return_output
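
# Minimal sketch of the per-block control flow implemented above (pseudocode, for illustration only):
#
#     if shared_state.should_compute:   # head block saw a large enough residual change
#         output = block(...)           # run normally; the tail block also caches residuals
#     else:                             # change was below the threshold
#         output = inputs, unchanged    # the head hook already added the cached tail residuals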


def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
    """
    Applies [First Block
    Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
    to a given module.

    First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much
    simpler to implement generically for a wide range of models, so it has been integrated first for experimental
    purposes.

    Args:
        module (`torch.nn.Module`):
            The PyTorch module to apply FBCache to. Typically, this should be a transformer architecture supported
            in Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
        config (`FirstBlockCacheConfig`):
            The configuration to use for applying the FBCache method.

    Example:
    ```python
    >>> import torch
    >>> from diffusers import CogView4Pipeline
    >>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig

    >>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
    >>> pipe.to("cuda")

    >>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))

    >>> prompt = "A photo of an astronaut riding a horse on mars"
    >>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
    >>> image.save("output.png")
    ```
    """

    state_manager = StateManager(FBCSharedBlockState, (), {})
    remaining_blocks = []

    # Collect all transformer blocks from the recognized block containers (e.g. `transformer_blocks`).
    for name, submodule in module.named_children():
        if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
            continue
        for index, block in enumerate(submodule):
            remaining_blocks.append((f"{name}.{index}", block))

    # The first block decides whether the rest should run; the last block records the tail residuals.
    # This assumes the model has at least two transformer blocks.
    head_block_name, head_block = remaining_blocks.pop(0)
    tail_block_name, tail_block = remaining_blocks.pop(-1)

    logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
    _apply_fbc_head_block_hook(head_block, state_manager, config.threshold)

    for name, block in remaining_blocks:
        logger.debug(f"Applying FBCBlockHook to '{name}'")
        _apply_fbc_block_hook(block, state_manager)

    logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
    _apply_fbc_block_hook(tail_block, state_manager, is_tail=True)


def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCHeadBlockHook(state_manager, threshold)
    registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)


def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
    registry = HookRegistry.check_if_exists_or_initialize(block)
    hook = FBCBlockHook(state_manager, is_tail)
    registry.register_hook(hook, _FBC_BLOCK_HOOK)