Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from dataclasses import dataclass | |
| import torch | |
| from .hooks import BaseState, HookRegistry, ModelHook, StateManager | |
# Names under which the hooks are registered in each module's HookRegistry:
# one hook on the top-level transformer, one on every transformer block.
_TEXT_KV_CACHE_TRANSFORMER_HOOK = "text_kv_cache_transformer"
_TEXT_KV_CACHE_BLOCK_HOOK = "text_kv_cache_block"
@dataclass
class TextKVCacheConfig:
    """Enable exact (lossless) text K/V caching for transformer models.

    With this cache applied, each transformer block computes its text key and value projections the first time a
    given prompt is seen and reuses them on every later forward pass. Positive and negative prompts are
    distinguished via a cache key captured by a transformer-level hook before the wrapped forward runs, i.e. before
    any intermediate tensors are allocated.

    The configuration currently has no options; it is a dataclass so that instances compare by value and have a
    readable ``repr``.
    """
class TextKVCacheState(BaseState):
    """State shared by the transformer-level hook and every block-level hook.

    The transformer hook stores the ``data_ptr()`` of ``encoder_hidden_states`` here (taken before the wrapped
    forward runs, hence before ``txt_norm``); block hooks read it back and use it as their cache key.
    """

    # ``None`` means no prompt tensor has been recorded yet.
    key: int | None

    def __init__(self):
        # A fresh state is indistinguishable from a reset one.
        self.reset()

    def reset(self):
        """Forget the recorded cache key."""
        self.key = None
class TextKVCacheBlockState(BaseState):
    """State owned by a single transformer block: its cached text projections.

    ``kv_cache`` maps a cache key (the integer recorded by the transformer-level hook) to the
    ``(text_key, text_value)`` tensor pair computed for that key.
    """

    def __init__(self):
        self.kv_cache: dict[int, tuple[torch.Tensor, torch.Tensor]] = dict()

    def reset(self):
        """Drop every cached pair, emptying the existing dict in place."""
        self.kv_cache.clear()
class TextKVCacheTransformerHook(ModelHook):
    """Record ``encoder_hidden_states.data_ptr()`` in shared state before the wrapped forward runs.

    The pointer is taken at the transformer's entry point (ahead of ``txt_norm``), and the block hooks later read
    it from the shared state to use as their cache key.
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager):
        super().__init__()
        self.state_manager = state_manager

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Store the prompt tensor's data pointer, then delegate to the original forward."""
        manager = self.state_manager
        # Make sure a context is selected before asking the manager for state.
        if manager._current_context is None:
            manager.set_context("inference")

        # NOTE(review): only the keyword form is inspected. If ``encoder_hidden_states`` is passed
        # positionally, the previously recorded key (or ``None``) is left in place.
        prompt_embeds = kwargs.get("encoder_hidden_states")
        if prompt_embeds is not None:
            shared: TextKVCacheState = manager.get_state()
            shared.key = prompt_embeds.data_ptr()

        return self.fn_ref.original_forward(*args, **kwargs)

    def reset_state(self, module: torch.nn.Module):
        """Reset the shared state manager and hand the module back."""
        self.state_manager.reset()
        return module
class TextKVCacheBlockHook(ModelHook):
    """Cache ``(txt_key, txt_value)`` per block and per prompt.

    The cache key is read from the state shared with the transformer-level hook. On a cache miss the block's text
    projections are computed and stored; on a hit they are reused. Either way the pair is handed to the wrapped
    forward through ``attention_kwargs`` under ``cached_txt_key`` / ``cached_txt_value`` (presumably consumed by
    the block's attention processor -- confirm against the model code).
    """

    _is_stateful = True

    def __init__(self, state_manager: StateManager, block_state_manager: StateManager):
        """
        Args:
            state_manager: Manager of the `TextKVCacheState` shared with the transformer-level hook.
            block_state_manager: Manager of this block's own `TextKVCacheBlockState`.
        """
        super().__init__()
        self.state_manager = state_manager
        self.block_state_manager = block_state_manager

    @staticmethod
    def _project_text_kv(module: torch.nn.Module, encoder_hidden_states, image_rotary_emb):
        """Compute the block's text key/value projections for ``encoder_hidden_states``."""
        from ..models.transformers.transformer_nucleusmoe_image import _apply_rotary_emb_nucleus

        context = module.encoder_proj(encoder_hidden_states)

        attn = module.attn
        head_dim = attn.inner_dim // attn.heads
        num_kv_heads = attn.inner_kv_dim // head_dim

        # Split the last dimension into (kv heads, per-head features).
        txt_key = attn.add_k_proj(context).unflatten(-1, (num_kv_heads, -1))
        txt_value = attn.add_v_proj(context).unflatten(-1, (num_kv_heads, -1))

        if attn.norm_added_k is not None:
            txt_key = attn.norm_added_k(txt_key)

        if image_rotary_emb is not None:
            # Only the second element (text frequencies) applies to the text keys.
            _, txt_freqs = image_rotary_emb
            txt_key = _apply_rotary_emb_nucleus(txt_key, txt_freqs, use_real=False)

        return txt_key, txt_value

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        """Inject cached text key/value tensors and delegate to the original forward.

        ``encoder_hidden_states`` is taken from ``kwargs`` or, failing that, from ``args[1]``;
        ``image_rotary_emb`` from ``kwargs`` or ``args[3]`` (``None`` when absent).

        Raises:
            IndexError: If ``encoder_hidden_states`` is supplied neither as a keyword nor positionally.
        """
        # Make sure both managers have a context selected before asking them for state.
        if self.state_manager._current_context is None:
            self.state_manager.set_context("inference")
        if self.block_state_manager._current_context is None:
            self.block_state_manager.set_context("inference")

        if "encoder_hidden_states" in kwargs:
            encoder_hidden_states = kwargs["encoder_hidden_states"]
        else:
            encoder_hidden_states = args[1]

        if "image_rotary_emb" in kwargs:
            image_rotary_emb = kwargs["image_rotary_emb"]
        elif len(args) > 3:
            image_rotary_emb = args[3]
        else:
            image_rotary_emb = None

        state: TextKVCacheState = self.state_manager.get_state()
        cache_key = state.key
        block_state: TextKVCacheBlockState = self.block_state_manager.get_state()

        if cache_key is None:
            # No key was recorded by the transformer-level hook, so prompts cannot be told apart.
            # Caching under ``None`` would hand one prompt's projections to another; recompute instead.
            txt_key, txt_value = self._project_text_kv(module, encoder_hidden_states, image_rotary_emb)
        else:
            cached = block_state.kv_cache.get(cache_key)
            if cached is None:
                cached = self._project_text_kv(module, encoder_hidden_states, image_rotary_emb)
                block_state.kv_cache[cache_key] = cached
            txt_key, txt_value = cached

        # Work on a copy so the caller's ``attention_kwargs`` dict (which may be shared across blocks
        # and across calls) is never mutated and does not end up holding references to cached tensors.
        attn_kwargs = dict(kwargs.get("attention_kwargs") or {})
        attn_kwargs["cached_txt_key"] = txt_key
        attn_kwargs["cached_txt_value"] = txt_value
        kwargs["attention_kwargs"] = attn_kwargs

        return self.fn_ref.original_forward(*args, **kwargs)

    def reset_state(self, module: torch.nn.Module):
        """Reset this block's state manager (dropping its cached projections) and hand the module back."""
        self.block_state_manager.reset()
        return module
def apply_text_kv_cache(module: torch.nn.Module, config: TextKVCacheConfig) -> None:
    """Attach text K/V caching hooks to ``module`` and to its transformer blocks.

    One `TextKVCacheTransformerHook` is registered on ``module`` itself, and one `TextKVCacheBlockHook` on every
    `NucleusMoEImageTransformerBlock` found among its modules. All hooks share a single `TextKVCacheState`
    manager; each block additionally gets its own `TextKVCacheBlockState` manager.

    Args:
        module: The transformer model to instrument.
        config: Caching configuration. Accepted for API symmetry; no field of it is read here.
    """
    from ..models.transformers.transformer_nucleusmoe_image import NucleusMoEImageTransformerBlock

    # Shared by the transformer hook (writer) and all block hooks (readers).
    state_manager = StateManager(TextKVCacheState)

    registry = HookRegistry.check_if_exists_or_initialize(module)
    registry.register_hook(TextKVCacheTransformerHook(state_manager), _TEXT_KV_CACHE_TRANSFORMER_HOOK)

    for submodule in module.modules():
        if not isinstance(submodule, NucleusMoEImageTransformerBlock):
            continue
        # Each block keeps its own cache, so it needs its own state manager.
        block_hook = TextKVCacheBlockHook(state_manager, StateManager(TextKVCacheBlockState))
        block_registry = HookRegistry.check_if_exists_or_initialize(submodule)
        block_registry.register_hook(block_hook, _TEXT_KV_CACHE_BLOCK_HOOK)