xiaoanyu123's picture
Add files using upload-large-folder tool
5e7c231 verified
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Tuple, Union
import torch
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
from ._helpers import TransformerBlockRegistry
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
logger = get_logger(__name__) # pylint: disable=invalid-name
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
_FBC_BLOCK_HOOK = "fbc_block_hook"
@dataclass
class FirstBlockCacheConfig:
r"""
Configuration for [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
Args:
threshold (`float`, defaults to `0.05`):
The threshold to determine whether or not a forward pass through all layers of the model is required. A
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
but might lead to poorer generation quality. A lower threshold may not result in significant generation
speedup. The threshold is compared against the absmean difference of the residuals between the current and
cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
is skipped.
"""
threshold: float = 0.05
class FBCSharedBlockState(BaseState):
def __init__(self) -> None:
super().__init__()
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.head_block_residual: torch.Tensor = None
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
self.should_compute: bool = True
def reset(self):
self.tail_block_residuals = None
self.should_compute = True
class FBCHeadBlockHook(ModelHook):
_is_stateful = True
def __init__(self, state_manager: StateManager, threshold: float):
self.state_manager = state_manager
self.threshold = threshold
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
output = self.fn_ref.original_forward(*args, **kwargs)
is_output_tuple = isinstance(output, tuple)
if is_output_tuple:
hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
else:
hidden_states_residual = output - original_hidden_states
shared_state: FBCSharedBlockState = self.state_manager.get_state()
hidden_states = encoder_hidden_states = None
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
shared_state.should_compute = should_compute
if not should_compute:
# Apply caching
if is_output_tuple:
hidden_states = (
shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
)
else:
hidden_states = shared_state.tail_block_residuals[0] + output
if self._metadata.return_encoder_hidden_states_index is not None:
assert is_output_tuple
encoder_hidden_states = (
shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
)
if is_output_tuple:
return_output = [None] * len(output)
return_output[self._metadata.return_hidden_states_index] = hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
return_output = tuple(return_output)
else:
return_output = hidden_states
output = return_output
else:
if is_output_tuple:
head_block_output = [None] * len(output)
head_block_output[0] = output[self._metadata.return_hidden_states_index]
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
else:
head_block_output = output
shared_state.head_block_output = head_block_output
shared_state.head_block_residual = hidden_states_residual
return output
def reset_state(self, module):
self.state_manager.reset()
return module
@torch.compiler.disable
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
shared_state = self.state_manager.get_state()
if shared_state.head_block_residual is None:
return True
prev_hidden_states_residual = shared_state.head_block_residual
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
diff = (absmean / prev_hidden_states_absmean).item()
return diff > self.threshold
class FBCBlockHook(ModelHook):
def __init__(self, state_manager: StateManager, is_tail: bool = False):
super().__init__()
self.state_manager = state_manager
self.is_tail = is_tail
self._metadata = None
def initialize_hook(self, module):
unwrapped_module = unwrap_module(module)
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
return module
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
original_encoder_hidden_states = None
if self._metadata.return_encoder_hidden_states_index is not None:
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
"encoder_hidden_states", args, kwargs
)
shared_state = self.state_manager.get_state()
if shared_state.should_compute:
output = self.fn_ref.original_forward(*args, **kwargs)
if self.is_tail:
hidden_states_residual = encoder_hidden_states_residual = None
if isinstance(output, tuple):
hidden_states_residual = (
output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
)
encoder_hidden_states_residual = (
output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
)
else:
hidden_states_residual = output - shared_state.head_block_output
shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
return output
if original_encoder_hidden_states is None:
return_output = original_hidden_states
else:
return_output = [None, None]
return_output[self._metadata.return_hidden_states_index] = original_hidden_states
return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
return_output = tuple(return_output)
return return_output
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
"""
Applies [First Block
Cache](https://github.com/chengzeyi/ParaAttention/blob/4de137c5b96416489f06e43e19f2c14a772e28fd/README.md#first-block-cache-our-dynamic-caching)
to a given module.
First Block Cache builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108). It is much simpler
to implement generically for a wide range of models and has been integrated first for experimental purposes.
Args:
module (`torch.nn.Module`):
The pytorch module to apply FBCache to. Typically, this should be a transformer architecture supported in
Diffusers, such as `CogVideoXTransformer3DModel`, but external implementations may also work.
config (`FirstBlockCacheConfig`):
The configuration to use for applying the FBCache method.
Example:
```python
>>> import torch
>>> from diffusers import CogView4Pipeline
>>> from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
>>> pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> apply_first_block_cache(pipe.transformer, FirstBlockCacheConfig(threshold=0.2))
>>> prompt = "A photo of an astronaut riding a horse on mars"
>>> image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
>>> image.save("output.png")
```
"""
state_manager = StateManager(FBCSharedBlockState, (), {})
remaining_blocks = []
for name, submodule in module.named_children():
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
continue
for index, block in enumerate(submodule):
remaining_blocks.append((f"{name}.{index}", block))
head_block_name, head_block = remaining_blocks.pop(0)
tail_block_name, tail_block = remaining_blocks.pop(-1)
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
for name, block in remaining_blocks:
logger.debug(f"Applying FBCBlockHook to '{name}'")
_apply_fbc_block_hook(block, state_manager)
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCHeadBlockHook(state_manager, threshold)
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
registry = HookRegistry.check_if_exists_or_initialize(block)
hook = FBCBlockHook(state_manager, is_tail)
registry.register_hook(hook, _FBC_BLOCK_HOOK)